
Commit

Merge pull request #18 from iSiddharth20/Dev
AutoEncoder Works, LSTM Doesn't Work, Issue with MaxEnt Calculation
iSiddharth20 authored Dec 27, 2023
2 parents bb07876 + abc695b commit 90e3b04
Showing 38 changed files with 282 additions and 235 deletions.
43 changes: 22 additions & 21 deletions Code/autoencoder_model.py
@@ -11,38 +11,39 @@
class Grey2RGBAutoEncoder(nn.Module):
def __init__(self):
super(Grey2RGBAutoEncoder, self).__init__()

'''
# Define the Encoder
The Encoder consists of 4 Convolutional layers, each followed by a ReLU activation.
It takes a 1-channel grayscale image as input and outputs a high-dimensional representation.
'''
self.encoder = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.ReLU()
)
self.encoder = self._make_layers([1, 64, 128, 256, 512])

'''
# Define the Decoder
The Decoder consists of 4 Transposed Convolutional layers, each followed by a ReLU activation.
It takes the high-dimensional representation as input and outputs a 3-channel RGB image.
The last layer uses a Sigmoid activation instead of ReLU.
'''
self.decoder = nn.Sequential(
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=1, padding=1, output_padding=0),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1, output_padding=0),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1, output_padding=0),
nn.ReLU(),
nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1, output_padding=0),
nn.Sigmoid()
)
self.decoder = self._make_layers([512, 256, 128, 64, 3], decoder=True)

# Helper function to create the encoder or decoder layers.
def _make_layers(self, channels, decoder=False):
layers = []
for i in range(len(channels) - 1):
'''
For each pair of consecutive values in the channels list, a Convolutional or Transposed Convolutional layer is created.
The number of input channels is the first value, and the number of output channels is the second value.
A ReLU activation function is added after each Convolutional layer.
'''
if decoder:
layers += [nn.ConvTranspose2d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True)]
else:
layers += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True)]
if decoder:
layers[-1] = nn.Sigmoid() # Replace last ReLU with Sigmoid for decoder
return nn.Sequential(*layers)

# The forward pass takes an input image, passes it through the encoder and decoder, and returns the output image
def forward(self, x):
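The body of forward is collapsed in the diff view above. As a quick sanity check, here is a hedged sketch (assuming Code/ is on the import path and that forward simply chains the encoder and decoder, as the comment above suggests):

import torch
from autoencoder_model import Grey2RGBAutoEncoder  # assumes Code/ is on sys.path

model = Grey2RGBAutoEncoder()
x = torch.rand(4, 1, 64, 64)  # dummy batch of 4 grayscale images
with torch.no_grad():
    y = model(x)
# All layers are 3x3 with stride 1 and padding 1, so spatial size is preserved:
print(y.shape)  # expected: torch.Size([4, 3, 64, 64])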
157 changes: 99 additions & 58 deletions Code/data.py
@@ -5,76 +5,117 @@
'''

# Import Necessary Libraries
import os
import numpy as np
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from pathlib import Path
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision.transforms as transforms
import torch
from torch.utils.data import DataLoader, TensorDataset

class Dataset:
def __init__(self, grayscale_dir, rgb_dir, image_size, batch_size):
self.grayscale_dir = grayscale_dir # Directory for grayscale images
self.rgb_dir = rgb_dir # Directory for RGB images
# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Define a custom dataset class
class CustomDataset(Dataset):
def __init__(self, grayscale_dir, rgb_dir, image_size, batch_size, valid_exts=['.tif', '.tiff']):
# Initialize directory paths and parameters
self.grayscale_dir = Path(grayscale_dir) # Directory for grayscale images
self.rgb_dir = Path(rgb_dir) # Directory for RGB images
self.image_size = image_size # Size to which images will be resized
self.batch_size = batch_size # Batch size for training
'''
Load greyscale and RGB images from their respective directories
Store all images of each directory in a normalized NumPy array
Convert the NumPy arrays to PyTorch Tensors
'''
self.grayscale_images = self.load_images_to_tensor(self.grayscale_dir)
self.rgb_images = self.load_images_to_tensor(self.rgb_dir)

# Load all images from a directory, resize them, and return them as a single tensor
def load_images_to_tensor(self, directory):
images = []
# Loop through all files in the directory
for filename in os.listdir(directory):
# If the file is an image file
if filename.endswith('.tif') or filename.endswith('.tiff'):
img_path = os.path.join(directory, filename)
img = Image.open(img_path)
# Resize the image
img = img.resize(self.image_size)
# Append the normalized image to the list
img_array = np.array(img).astype('float32') / 255.0
# Add an extra dimension for grayscale images
if len(img_array.shape) == 2:
img_array = np.expand_dims(img_array, axis=-1)
images.append(img_array)
# Return a PyTorch Tensor (shape [m, C, H, W]) created from the NumPy array of images
images = np.array(images)
images = torch.tensor(images, dtype=torch.float32).permute(0, 3, 1, 2)
return images

# Get batches of input-target pairs from the data (used by the AutoEncoder component)
def get_autoencoder_batches(self,val_split):
# Create a Dataset from the Tensors
dataset = TensorDataset(self.grayscale_images, self.rgb_images)
self.valid_exts = valid_exts # Valid file extensions
# Get list of valid image filenames
self.filenames = [f for f in self.grayscale_dir.iterdir() if f.suffix in self.valid_exts]
# Define transformations: resize and convert to tensor
self.transform = transforms.Compose([
transforms.Resize(self.image_size),
transforms.ToTensor()])

# Return the total number of images
def __len__(self):
return len(self.filenames)

# Get a single item or a slice from the dataset
def __getitem__(self, idx):
if isinstance(idx, slice):
return [self[i] for i in range(*idx.indices(len(self)))]
# Get paths for grayscale and RGB images
grayscale_path = self.filenames[idx]
rgb_path = self.rgb_dir / grayscale_path.name
# Open images
grayscale_img = Image.open(grayscale_path)
rgb_img = Image.open(rgb_path)
# Apply transformations
grayscale_img = self.transform(grayscale_img)
rgb_img = self.transform(rgb_img)
# Return transformed images
return grayscale_img, rgb_img
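
A minimal usage sketch for the dataset (the directory paths are hypothetical; grayscale and RGB files are assumed to share filenames, since __getitem__ pairs them by name):

from data import CustomDataset  # assumes Code/ is on sys.path

dataset = CustomDataset(grayscale_dir='Data/grey', rgb_dir='Data/rgb',
                        image_size=(64, 64), batch_size=8)
grey, rgb = dataset[0]        # one transformed pair
pairs = dataset[0:4]          # slice support: returns a list of pairs
print(grey.shape, rgb.shape)  # e.g. torch.Size([1, 64, 64]) torch.Size([3, 64, 64])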

# Get batches for autoencoder training
def get_autoencoder_batches(self, val_split):
# Calculate the number of samples to include in the validation set
val_size = int(val_split * len(dataset))
train_size = len(dataset) - val_size
val_size = int(val_split * len(self))
train_size = len(self) - val_size
# Split the dataset into training and validation sets
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_dataset, val_dataset = random_split(self, [train_size, val_size])
# Create dataloaders for the training and validation sets
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
print("Sample from autoencoder training data:")
for sample in train_loader:
print(f'Input shape: {sample[0].shape}, Target shape: {sample[1].shape}')
break # Just print the first sample and break
val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=True)
# Return the training and validation dataloaders
return train_loader, val_loader
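
Continuing the sketch above, the autoencoder loaders can then be obtained as:

train_loader, val_loader = dataset.get_autoencoder_batches(val_split=0.2)
grey_batch, rgb_batch = next(iter(train_loader))
print(grey_batch.shape, rgb_batch.shape)  # e.g. torch.Size([8, 1, 64, 64]) torch.Size([8, 3, 64, 64])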

# Get batches of original-sequence / interpolated-sequence pairs (used by the LSTM component)
def get_lstm_batches(self):
# Add an extra dimension to the grayscale images tensor
greyscale_image_sequence = self.grayscale_images.unsqueeze(0)
# Split the sequence into training and validation sets
greyscale_image_sequence_train = greyscale_image_sequence[:, 1::2] # All odd-indexed images for Training
greyscale_image_sequence_val = greyscale_image_sequence # All images for Validation of Interpolated Frames
# Create TensorDatasets
train_data = TensorDataset(greyscale_image_sequence_train)
val_data = TensorDataset(greyscale_image_sequence_val)
# Create DataLoaders
train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=self.batch_size, shuffle=True)
# Get batches for LSTM training
def get_lstm_batches(self, val_split, n=1):
# Calculate the number of samples to include in the validation set
val_size = int(val_split * (len(self) // 2)) # Half of sequences because we use every second image.
train_size = (len(self) // 2) - val_size

# Get indices for the odd-numbered frames (inputs) and even-numbered frames (targets);
# counting frames from 1, these are 0-based indices 0, 2, 4, ... and 1, 3, 5, ...
odd_indices = list(range(0, len(self), 2))
even_indices = list(range(1, len(self), 2))

# Split the dataset indices into training and validation subsets
train_odd_indices = odd_indices[:train_size]
val_odd_indices = odd_indices[train_size:]

train_even_indices = even_indices[:train_size]
val_even_indices = even_indices[train_size:]

# Define a helper function to extract sequences by indices
def extract_sequences(indices):
return [self[i][0] for i in indices] # Only return the grayscale images, not the tuples

# Use the helper function to create training and validation sets
train_input_seqs = torch.stack(extract_sequences(train_odd_indices))
train_target_seqs = torch.stack(extract_sequences(train_even_indices))

val_input_seqs = torch.stack(extract_sequences(val_odd_indices))
val_target_seqs = torch.stack(extract_sequences(val_even_indices))

# Create custom Dataset for the LSTM sequences
class LSTMDataset(Dataset):
def __init__(self, input_seqs, target_seqs):
self.input_seqs = input_seqs
self.target_seqs = target_seqs

def __len__(self):
return len(self.input_seqs)

def __getitem__(self, idx):
return self.input_seqs[idx], self.target_seqs[idx]

# Instantiate the custom Dataset objects
train_dataset = LSTMDataset(train_input_seqs, train_target_seqs)
val_dataset = LSTMDataset(val_input_seqs, val_target_seqs)

# Create DataLoaders for the LSTM datasets
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=True)

# Return the training and validation DataLoaders
return train_loader, val_loader
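
And for the LSTM side (same hypothetical dataset as above). Note that, as written, each item is a single input-frame/target-frame pair rather than a multi-frame sequence:

train_loader, val_loader = dataset.get_lstm_batches(val_split=0.2)
inputs, targets = next(iter(train_loader))
print(inputs.shape, targets.shape)  # e.g. torch.Size([8, 1, 64, 64]) each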

24 changes: 11 additions & 13 deletions Code/losses.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
# Import Necessary Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_msssim import SSIM

'''
@@ -18,25 +19,24 @@ class LossMEP(nn.Module):
def __init__(self, alpha=0.5):
super(LossMEP, self).__init__()
self.alpha = alpha # Weighting factor for the loss
self.mse = nn.MSELoss() # Mean Squared Error loss

def forward(self, output, target):
mse_loss = self.mse(output, target) # Compute MSE Loss
entropy = -torch.sum(target * torch.log(output + 1e-8), dim=-1).mean() # Compute Entropy
composite_loss = self.alpha * mse_loss + (1 - self.alpha) * entropy # Compute Composite Loss
mse_loss = F.mse_loss(output, target) # Compute MSE Loss using functional API
# Normalize the output tensor along the last dimension to represent probabilities
output_normalized = torch.softmax(output, dim=-1)
# Compute Entropy
entropy = -torch.sum(target * torch.log(output_normalized + 1e-8), dim=-1).mean()
# Compute Composite Loss
composite_loss = self.alpha * mse_loss + (1 - self.alpha) * entropy
return composite_loss
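
A hedged sketch exercising the composite loss on dummy tensors; note the entropy term is computed against the softmax-normalized output along the last dimension, so it behaves like a cross-entropy between target and output rather than a plain entropy:

import torch
from losses import LossMEP  # assumes Code/ is on sys.path

criterion = LossMEP(alpha=0.5)
output = torch.rand(2, 3, 8, 8)  # dummy prediction in [0, 1]
target = torch.rand(2, 3, 8, 8)  # dummy target in [0, 1]
loss = criterion(output, target)
print(loss.item())  # scalar: 0.5 * MSE + 0.5 * entropy term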

'''
Class for Mean Squared Error (MSE) Loss
- Maximum Likelihood Principle
'''
class LossMSE(nn.Module):
def __init__(self):
super(LossMSE, self).__init__()
self.mse = nn.MSELoss() # Mean Squared Error loss

def forward(self, output, target):
likelihood_loss = self.mse(output, target) # Compute MSE loss
likelihood_loss = F.mse_loss(output, target) # Compute MSE loss using functional API
return likelihood_loss

'''
@@ -47,11 +47,9 @@ def forward(self, output, target):
class SSIMLoss(nn.Module):
def __init__(self, data_range=1, size_average=True):
super(SSIMLoss, self).__init__()
self.data_range = data_range # The range of the input image (usually 1.0 or 255)
self.size_average = size_average # If True, the SSIM of all windows are averaged
# Initialize SSIM module
self.ssim_module = SSIM(data_range=self.data_range, size_average=self.size_average)
self.ssim_module = SSIM(data_range=data_range, size_average=size_average)

def forward(self, img1, img2):
ssim_value = self.ssim_module(img1, img2) # Compute SSIM
return 1 - ssim_value # Return loss
return 1 - ssim_value # Return loss
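
A matching sketch for the SSIM loss (pytorch-msssim's SSIM module defaults to 3-channel input; with data_range=1 the images are expected in [0, 1]):

import torch
from losses import SSIMLoss  # assumes Code/ is on sys.path

ssim_loss = SSIMLoss(data_range=1, size_average=True)
img1 = torch.rand(2, 3, 64, 64)
img2 = img1.clone()
print(ssim_loss(img1, img2).item())  # ~0.0 for identical images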
