Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated Distributed Training and Added Inference Module #26

Merged
merged 24 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1517086
Now Returns Best Model Stats (Epoch Num, Train Loss, Val Loss)
iSiddharth20 Jan 7, 2024
5dc78f9
Displays Best Model Stats (Epoch Num, Train Loss, Val Loss) for all m…
iSiddharth20 Jan 7, 2024
a202cba
Added Stats Display for Method-4 Models
iSiddharth20 Jan 7, 2024
4304e30
Added DistributedSampler for PyTorch DDP
iSiddharth20 Jan 7, 2024
7206f7d
Made SSIM Regularization Faster
iSiddharth20 Jan 7, 2024
b81dc09
Cleaner Stats Display at the End
iSiddharth20 Jan 7, 2024
22eb9c7
Added Learning Rate Scheduler to Model Training
iSiddharth20 Jan 8, 2024
a164a85
Display Current Learning Rate along with Epoch Number and Losses
iSiddharth20 Jan 8, 2024
d1c1495
Rounded Off Loss Values to 10 Decimal Places
iSiddharth20 Jan 8, 2024
45769a0
Corrected lr-scheduler step position
iSiddharth20 Jan 8, 2024
f1d7006
Sync Code Across Devices
iSiddharth20 Jan 8, 2024
c5b9c67
Corrected initialization of optimizers in trainer objects
iSiddharth20 Jan 8, 2024
5529f96
Fixed Parameters for LR Scheduler
iSiddharth20 Jan 8, 2024
20742fb
Experimenting with SGD Optimizer
iSiddharth20 Jan 8, 2024
f877888
Updated Optimizers, LR Schedulers for Each Method
iSiddharth20 Jan 8, 2024
215faf4
Adam for AutoEncoder, SGD with Momentum for LSTM
iSiddharth20 Jan 8, 2024
ac4560a
First Attempt to Generate Desired Results
iSiddharth20 Jan 11, 2024
9444de9
Works, but uses too much CPU and RAM
iSiddharth20 Jan 11, 2024
bc6af54
Added CUDA Support
iSiddharth20 Jan 11, 2024
c87ae15
Inference Module
iSiddharth20 Jan 11, 2024
77c87d8
Changed Hidden Layers, Activation Function, Batch Normalization
iSiddharth20 Jan 11, 2024
92439a3
Changed Batch Size for both models
iSiddharth20 Jan 11, 2024
86b7785
Changed Hidden Layers
iSiddharth20 Jan 11, 2024
0ba0739
Changes Epochs
iSiddharth20 Jan 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions Code/GenerateResults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
'''
Generate Results from Trained Models
'''

# Import Necessary Libraries
import platform
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.multiprocessing import Process
from PIL import Image
from torchvision import transforms
import glob
import shutil

# Import Model Definitions
from autoencoder_model import Grey2RGBAutoEncoder
from lstm_model import ConvLSTM

# Define Universal Variables
# Target frame size in pixels: every input image is resized to
# (image_height, image_width) by the `transform` pipeline in this file.
image_width = 1280
image_height = 720

# Define Backend for Distributed Computing
def get_backend():
    """Pick the torch.distributed backend name for the current operating system.

    Returns:
        "nccl" when platform.system() reports "Linux", otherwise "gloo".
    """
    return "nccl" if platform.system() == "Linux" else "gloo"

# Function to initialize the process group for distributed computing
def setup(rank, world_size):
    """Join this process to the process group and select its CUDA device.

    Args:
        rank: index of this process within the group; also used as the
            CUDA device index passed to torch.cuda.set_device.
        world_size: total number of processes taking part.
    """
    # Every process rendezvous on this machine at a fixed port.
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12355')
    backend_name = get_backend()
    dist.init_process_group(backend=backend_name, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

# Function to clean up the process group after computation
def cleanup():
    """Destroy the process group created by setup() once a worker is done."""
    dist.destroy_process_group()

# The function to load your models
def load_model(model, model_path, device):
    """Load saved weights into ``model`` and return it wrapped in DDP.

    Args:
        model: freshly constructed module whose parameter names match the
            checkpoint (after the DDP prefix has been stripped).
        model_path: path of the checkpoint file; presumably a state_dict
            written with torch.save -- confirm against the training code.
        device: CUDA device index that the weights and the model are moved to.

    Returns:
        ``model`` on ``device``, wrapped in DistributedDataParallel with
        ``device_ids=[device]``.
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(
        model_path,
        map_location=lambda storage, loc: storage.cuda(device),
    )
    # Keys saved from a DDP-wrapped model start with 'module.'. Strip it only
    # as a leading prefix: a blanket str.replace would also mangle any key
    # that merely contains the text (e.g. 'submodule.weight' -> 'subweight').
    prefix = 'module.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(new_state_dict)
    # Move the model to the device and wrap it with DDP only after its
    # state_dict has been loaded, so the un-prefixed keys match.
    model = model.to(device)
    return DDP(model, device_ids=[device])

# Define the function to save images
def save_images(img_seq, img_dir, global_start_idx):
    """Write every tensor in ``img_seq`` to ``img_dir`` as a numbered TIFF.

    Args:
        img_seq: iterable of image tensors; each is moved to the CPU and
            converted with torchvision's ToPILImage.
        img_dir: directory the files are written into.
        global_start_idx: number given to the first image; later images
            count up from it. File names are ``image_<idx:04d>.tif``.
    """
    tensor_to_pil = transforms.ToPILImage()
    for global_idx, frame in enumerate(img_seq, start=global_start_idx):
        tensor_to_pil(frame.cpu()).save(f'{img_dir}/image_{global_idx:04d}.tif')

def reorder_and_save_images(img_exp_dir, output_dir):
    """Renumber the intermediate images consecutively into ``output_dir``.

    Files named ``image_<n>.tif`` in ``img_exp_dir`` are sorted by the integer
    ``<n>`` and re-saved as ``enhanced_sequence_<i:04d>.tif`` where ``i`` counts
    up from 0, so gaps in the intermediate numbering disappear.

    Args:
        img_exp_dir: directory holding the intermediate ``image_*.tif`` files.
        output_dir: existing directory that receives the renumbered files.
    """
    def frame_number(path):
        # 'image_0012.tif' -> 12
        return int(os.path.basename(path).split('_')[1].split('.')[0])

    image_paths = glob.glob(os.path.join(img_exp_dir, 'image_*.tif'))
    for i, img_path in enumerate(sorted(image_paths, key=frame_number)):
        # Close each source file as soon as it has been re-saved: leaving the
        # handles open leaks descriptors and can block the caller from
        # deleting the intermediate directory afterwards.
        with Image.open(img_path) as img:
            img.save(os.path.join(output_dir, f'enhanced_sequence_{i:04d}.tif'))

# Define the Transformation
# Preprocessing applied to each input frame in enhance(): resize to the
# configured frame size, convert to grayscale, then turn the PIL image
# into a tensor.
transform = transforms.Compose([
transforms.Resize((image_height, image_width)),
transforms.Grayscale(), # Convert the images to grayscale
transforms.ToTensor(),
])

# The main function that will be executed by each process
def enhance(rank, world_size, img_inp_dir, img_exp_dir, lstm_path, autoencoder_path):
    """Worker entry point: process this rank's share of the input frames.

    The sorted input frames are split into contiguous shards, one per rank.
    Each rank runs its shard through the ConvLSTM, interleaves every input
    frame with the corresponding LSTM output, passes each interleaved frame
    through the autoencoder, and writes the results to ``img_exp_dir``.

    Args:
        rank: index of this process; also the CUDA device it runs on.
        world_size: total number of worker processes.
        img_inp_dir: directory containing the input frames.
        img_exp_dir: directory that receives the intermediate output images.
        lstm_path: checkpoint path for the ConvLSTM.
        autoencoder_path: checkpoint path for the Grey2RGBAutoEncoder.
    """
    setup(rank, world_size)
    try:
        lstm_model = ConvLSTM(input_dim=1, hidden_dims=[1, 1, 1], kernel_size=(3, 3), num_layers=3, alpha=0.6)
        lstm = load_model(lstm_model, lstm_path, rank)
        lstm.eval()
        autoencoder_model = Grey2RGBAutoEncoder()
        autoencoder = load_model(autoencoder_model, autoencoder_path, rank)
        autoencoder.eval()

        # os.listdir returns names in arbitrary order; sort them so the frames
        # form the same, name-ordered sequence on every rank.
        image_files = sorted(os.listdir(img_inp_dir))
        per_gpu = (len(image_files) + world_size - 1) // world_size  # ceil division
        start_idx = rank * per_gpu
        end_idx = min(start_idx + per_gpu, len(image_files))
        # Each input frame yields two saved images (the input and the LSTM
        # output), so this rank's first output number is twice its first input
        # index. Using start_idx itself made neighbouring ranks overwrite each
        # other's files.
        global_start_idx = 2 * start_idx

        local_frames = []
        for file_name in image_files[start_idx:end_idx]:
            # Transform while the file is open, then close it straight away.
            with Image.open(os.path.join(img_inp_dir, file_name)) as image:
                local_frames.append(transform(image))
        # Add a leading batch dimension of 1 and move the shard to this GPU.
        local_tensors = torch.stack(local_frames).unsqueeze(0).to(rank)

        with torch.no_grad():
            local_output_sequence, _ = lstm(local_tensors)
            local_output_sequence = local_output_sequence.squeeze(0)
            # Interleave the input and output images: in0, out0, in1, out1, ...
            interleaved_sequence = torch.stack(
                [t for pair in zip(local_tensors.squeeze(0), local_output_sequence) for t in pair]
            )
            # The autoencoder is applied one frame at a time (batch of 1).
            local_output_enhanced = torch.stack(
                [autoencoder(t.unsqueeze(0)) for t in interleaved_sequence]
            ).squeeze(1)
        save_images(local_output_enhanced, img_exp_dir, global_start_idx)
    finally:
        # Always leave the process group, even if inference failed.
        cleanup()


if __name__ == "__main__":
    # One worker process is started per visible CUDA device.
    world_size = torch.cuda.device_count()
    if world_size == 0:
        # Each worker binds to a CUDA device, so there is nothing to run.
        raise RuntimeError('No CUDA devices found; inference requires at least one GPU.')

    # Input Sequence Directory (All Methods)
    img_sequence_inp_dir = r'../Dataset/Inference/InputSequence'
    # Intermediate Results will be Stored in this Directory which later will be re-ordered (All Methods)
    temp_dir = r'../Dataset/Inference/OutputSequence/Temp'
    os.makedirs(temp_dir, exist_ok=True)

    # Working Directories for (Method-1)
    autoencoder_path = r'../Models/Method1/model_autoencoder_m1.pth'
    lstm_path = r'../Models/Method1/model_lstm_m1.pth'
    img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method1/'
    os.makedirs(img_sequence_out_dir, exist_ok=True)

    # Working Directories for (Method-2)
    # autoencoder_path = r'../Models/Method2/model_autoencoder_m2.pth'
    # lstm_path = r'../Models/Method1/model_lstm_m1.pth'
    # img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method2/'
    # os.makedirs(img_sequence_out_dir, exist_ok=True)

    # Working Directories for (Method-3)
    # autoencoder_path = r'../Models/Method1/model_autoencoder_m1.pth'
    # lstm_path = r'../Models/Method3/model_lstm_m3.pth'
    # img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method3/'
    # os.makedirs(img_sequence_out_dir, exist_ok=True)

    # Working Directories for (Method-4)
    # autoencoder_path = r'../Models/Method2/model_autoencoder_m2.pth'
    # lstm_path = r'../Models/Method3/model_lstm_m3.pth'
    # img_sequence_out_dir = r'../Dataset/Inference/OutputSequence/Method4/'
    # os.makedirs(img_sequence_out_dir, exist_ok=True)

    try:
        processes = []
        for rank in range(world_size):
            p = Process(target=enhance, args=(rank, world_size, img_sequence_inp_dir, temp_dir, lstm_path, autoencoder_path))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()

        # Do not publish a partial sequence if any worker died.
        failed_ranks = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
        if failed_ranks:
            raise RuntimeError(f'Inference worker(s) failed on rank(s): {failed_ranks}')

        # Reorder images once processing by all GPUs is complete
        reorder_and_save_images(temp_dir, img_sequence_out_dir)
    finally:
        # Delete all Intermediate Results, also on failure, so leftovers from
        # this run cannot be picked up by the next one.
        shutil.rmtree(temp_dir)

10 changes: 4 additions & 6 deletions Code/autoencoder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,20 @@ class Grey2RGBAutoEncoder(nn.Module):
def __init__(self):
super(Grey2RGBAutoEncoder, self).__init__()
# Define the Encoder
self.encoder = self._make_layers([1, 3, 6, 12, 24])
self.encoder = self._make_layers([1, 4, 8, 16, 32])
# Define the Decoder
self.decoder = self._make_layers([24, 12, 6, 3], decoder=True)
self.decoder = self._make_layers([32, 16, 8, 4, 3], decoder=True)

# Helper function to create the encoder or decoder layers.
def _make_layers(self, channels, decoder=False):
layers = []
for i in range(len(channels) - 1):
if decoder:
layers += [nn.ConvTranspose2d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(channels[i+1]),
nn.LeakyReLU(inplace=True)]
nn.ReLU(inplace=True)]
else:
layers += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(channels[i+1]),
nn.LeakyReLU(inplace=True)]
nn.ReLU(inplace=True)]
if decoder:
layers[-1] = nn.Sigmoid()
return nn.Sequential(*layers)
Expand Down
14 changes: 9 additions & 5 deletions Code/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms as transforms
import torch
import os
from torch.utils.data.distributed import DistributedSampler

# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True
Expand Down Expand Up @@ -66,8 +66,10 @@ def get_autoencoder_batches(self, val_split, batch_size):
# Split the dataset into training and validation sets
train_dataset, val_dataset = random_split(self, [train_size, val_size])
# Create dataloaders for the training and validation sets
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size, shuffle=True)
train_sampler = DistributedSampler(train_dataset)
val_sampler = DistributedSampler(val_dataset)
train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, sampler=val_sampler)
# Return the training and validation dataloaders
return train_loader, val_loader

Expand All @@ -92,8 +94,10 @@ def get_lstm_batches(self, val_split, sequence_length, batch_size):
train_dataset.append((sequence_input_train, sequence_target_train))
val_dataset.append((sequence_input_val, sequence_target_val))
# Create the data loaders for training and validation datasets
train_loader = DataLoader(train_dataset, batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size, shuffle=False)
train_sampler = DistributedSampler(train_dataset)
val_sampler = DistributedSampler(val_dataset)
train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, sampler=train_sampler)
val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, sampler=val_sampler)
return train_loader, val_loader

def transform_sequence(self, filenames, lstm=False):
Expand Down
15 changes: 6 additions & 9 deletions Code/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Class for Composite Loss with Maximum Entropy Principle Regularization Term
'''
class LossMEP(nn.Module):
def __init__(self, alpha=0.5):
def __init__(self, alpha=0.1):
super(LossMEP, self).__init__()
self.alpha = alpha # Weighting factor for total variation loss

Expand All @@ -28,9 +28,9 @@ def forward(self, output, target):
torch.sum(torch.abs(output[:, :, :, :-1] - output[:, :, :, 1:]))
tv_loss /= batch_size * height * width # Normalize by total size
# Composite loss
loss = mse_loss + self.alpha * tv_loss
combined_loss = (1 - self.alpha) * mse_loss + self.alpha * tv_loss
# Return the composite loss
return loss
return combined_loss

'''
Class for Mean Squared Error (MSE) Loss
Expand All @@ -45,24 +45,21 @@ def forward(self, output, target):
- In PyTorch, loss is minimized, by doing 1 - SSIM, minimizing the loss function will lead to maximization of SSIM
'''
class SSIMLoss(nn.Module):
def __init__(self, alpha=0.5):
def __init__(self, alpha=0.1):
super(SSIMLoss, self).__init__()
self.alpha = alpha
self.ssim_module = SSIM(data_range=1, size_average=True, channel=1)

def forward(self, seq1, seq2):
N, T = seq1.shape[:2]
ssim_values = []
mse_values = []
for i in range(N):
for t in range(T):
seq1_slice = seq1[i, t:t+1, ...]
seq2_slice = seq2[i, t:t+1, ...]
ssim_val = self.ssim_module(seq1_slice, seq2_slice)
mse_val = F.mse_loss(seq1_slice, seq2_slice)
ssim_values.append(ssim_val) # Compute SSIM for each frame in the sequence
mse_values.append(mse_val) # Compute MSE for each frame in the sequence
avg_ssim = torch.stack(ssim_values).mean() # Average SSIM across all frames
avg_mse = torch.stack(mse_values).mean() # Average MSE across all frames
combined_loss = (1 - self.alpha) * avg_mse + self.alpha * (1 - avg_ssim) # SSIM is maximized, so we subtract from 1
mse_loss = F.mse_loss(seq1, seq2)
combined_loss = (1 - self.alpha) * mse_loss + self.alpha * (1 - avg_ssim) # SSIM is maximized, so we subtract from 1
return combined_loss
Loading