In this project, we implement a streamlined U-Net architecture using PyTorch 2.2.1. The implementation features Conv2d layers and a custom convolution layer, CustConv, designed to minimize the number of parameters.
The U-Net architecture takes an input tensor of shape [256, T, 1] and outputs a tensor of the same shape. The tensor shapes at every stage of the network are listed below:
[256, T, 1] → [256, T, 4] → [128, T, 4] → [64, T, 4] → [32, T, 8] → [16, T, 8] → [8, T, 16] → [16, T, 8] → [32, T, 8] → [64, T, 4] → [128, T, 4] → [256, T, 4] → [256, T, 1]
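For illustration, this shape progression can be reproduced with the minimal sketch below, built from plain Conv2d blocks. It is not the project's actual model: the kernel sizes, strides, activations, absence of skip connections, and the [batch, channels, 256-axis, T] memory layout are assumptions made only to show how the listed shapes arise (each downsampling step halves the first spatial axis and leaves T unchanged). The shape comments use the document's [size, T, channels] notation.

import torch
import torch.nn as nn

def down(in_ch, out_ch):
    # Stride (2, 1): halve the first spatial axis, keep the time axis T unchanged.
    return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=(2, 1), padding=1), nn.ReLU())

def up(in_ch, out_ch):
    # Transposed convolution with stride (2, 1): double the first spatial axis, keep T.
    return nn.Sequential(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=(2, 1),
                                            padding=1, output_padding=(1, 0)), nn.ReLU())

net = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),   # [256, T, 1]  -> [256, T, 4]
    down(4, 4),                                  # [256, T, 4]  -> [128, T, 4]
    down(4, 4),                                  # [128, T, 4]  -> [64, T, 4]
    down(4, 8),                                  # [64, T, 4]   -> [32, T, 8]
    down(8, 8),                                  # [32, T, 8]   -> [16, T, 8]
    down(8, 16),                                 # [16, T, 8]   -> [8, T, 16]
    up(16, 8),                                   # [8, T, 16]   -> [16, T, 8]
    up(8, 8),                                    # [16, T, 8]   -> [32, T, 8]
    up(8, 4),                                    # [32, T, 8]   -> [64, T, 4]
    up(4, 4),                                    # [64, T, 4]   -> [128, T, 4]
    up(4, 4),                                    # [128, T, 4]  -> [256, T, 4]
    nn.Conv2d(4, 1, kernel_size=3, padding=1),   # [256, T, 4]  -> [256, T, 1]
)

x = torch.randn(1, 1, 256, 40)   # batch of one, with T = 40 as an example
print(net(x).shape)              # torch.Size([1, 1, 256, 40])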
The architecture includes two types of convolution operations:
- Standard Conv2d - Regular 2D convolution layers.
- CustConv - A custom convolution operator designed to simplify the model by reducing the kernel size while maintaining the same receptive field. The CustConv layer first divides the input channels into two groups: static and dynamic. The dynamic channels are then time-shifted, half of them forward by one timestep and the other half backward by one timestep, while the static channels remain unchanged; the result has the same shape as the original input. A 2D convolution is then applied to the time-shifted tensor. Because neighbouring timesteps have been mixed into the shifted channels, this convolution can capture temporal patterns and dependencies even with a small kernel (a minimal sketch follows this list).
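To make this concrete, below is a minimal sketch of such a layer. The half/half split between dynamic and static channels, the (3, 1) kernel, and the use of torch.roll plus zero-padding for the shift are illustrative assumptions; the project itself implements the shift with the matrix-multiplication masks described next.

import torch
import torch.nn as nn

class CustConv(nn.Module):
    """Illustrative sketch: time-shift half of the channels, then convolve."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.n_dynamic = in_ch // 2          # assumed split between dynamic and static channels
        # Kernel of size (3, 1): the temporal context comes from the shifted channels,
        # so the kernel itself does not need to extend along the time axis.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=(3, 1), padding=(1, 0))

    def forward(self, x):                    # x: [batch, channels, 256-axis, T]
        x = x.clone()
        half = self.n_dynamic // 2
        # First half of the dynamic channels: shift forward by one timestep, zero-pad the first step.
        x[:, :half] = torch.roll(x[:, :half], shifts=1, dims=-1)
        x[:, :half, :, 0] = 0
        # Second half of the dynamic channels: shift backward by one timestep, zero-pad the last step.
        x[:, half:self.n_dynamic] = torch.roll(x[:, half:self.n_dynamic], shifts=-1, dims=-1)
        x[:, half:self.n_dynamic, :, -1] = 0
        # Static channels are left untouched; the 2D convolution then mixes all channels.
        return self.conv(x)

print(CustConv(4, 4)(torch.randn(1, 4, 256, 40)).shape)   # torch.Size([1, 4, 256, 40])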
We have used matrix multiplication for implementing the time shift in a channel. Let us assume we have data for a single channel in the form of the matrix
\[
X = \begin{bmatrix} x_1 & x_2 & \cdots & x_T \end{bmatrix} \in \mathbb{R}^{F \times T},
\]
where column $x_t$ holds the values of the channel at timestep $t$. In our use case, $F$ is the size of the non-temporal dimension (256 at the input and output of the network) and $T$ is the number of timesteps.
Now consider the following mask matrix
\[
M_{\mathrm{right}} =
\begin{bmatrix}
0 & 1 & 0 & \cdots & 0 \\
0 & 0 & 1 & \cdots & 0 \\
\vdots & & & \ddots & \vdots \\
0 & 0 & 0 & \cdots & 1 \\
0 & 0 & 0 & \cdots & 0
\end{bmatrix}
\in \mathbb{R}^{T \times T},
\]
where the first column is all zeros and the remaining columns form an identity matrix rolled one position to the right. By right multiplying $X$ by this matrix we obtain
\[
X M_{\mathrm{right}} = \begin{bmatrix} 0 & x_1 & x_2 & \cdots & x_{T-1} \end{bmatrix},
\]
that is, every timestep moves one position forward and the first timestep is zero-padded. So right multiplication by $M_{\mathrm{right}}$ shifts the channel forward by one timestep. Analogously, consider the mask matrix
\[
M_{\mathrm{left}} =
\begin{bmatrix}
0 & 0 & \cdots & 0 & 0 \\
1 & 0 & \cdots & 0 & 0 \\
0 & 1 & \cdots & 0 & 0 \\
\vdots & & \ddots & & \vdots \\
0 & 0 & \cdots & 1 & 0
\end{bmatrix}
\in \mathbb{R}^{T \times T},
\]
where the rightmost column is all zeros and the remaining columns form an identity matrix rolled one position to the left. Right multiplying $X$ by $M_{\mathrm{left}}$ gives
\[
X M_{\mathrm{left}} = \begin{bmatrix} x_2 & x_3 & \cdots & x_T & 0 \end{bmatrix},
\]
a shift of one timestep backward with the last timestep zero-padded.
In this project, we use the same mechanics for implementing time shifts. If a PyTorch tensor stores the timesteps along its last dimension, the same right multiplication can be applied to every channel at once with a batched matrix product, using an identity matrix for the static channels and a shift mask for the dynamic ones. To create the right- and left-shift masks we define
import torch

def shift_right_mask(n, dtype):
    # Identity rolled one column to the right; zero the first column to remove the wrap-around.
    mask = torch.roll(torch.eye(n, dtype=dtype), shifts=[1], dims=[1])
    mask[:, 0] = 0
    return mask

def shift_left_mask(n, dtype):
    # Identity rolled one column to the left; zero the last column to remove the wrap-around.
    mask = torch.roll(torch.eye(n, dtype=dtype), shifts=[-1], dims=[1])
    mask[:, -1] = 0
    return mask
which shift an identity matrix one column to the left or to the right and then set the wrapped-around column to zero.
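For example (a toy matrix, only to sanity-check the helpers above), right multiplication behaves exactly as described:

import torch

x = torch.arange(1., 13.).reshape(3, 4)      # one channel with 3 rows and T = 4 timesteps
print(x @ shift_right_mask(4, x.dtype))      # forward shift:  [[0, 1, 2, 3], [0, 5, 6, 7], [0, 9, 10, 11]]
print(x @ shift_left_mask(4, x.dtype))       # backward shift: [[2, 3, 4, 0], [6, 7, 8, 0], [10, 11, 12, 0]]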
These two helpers are then combined to build the final per-channel mask:
def shift_mask(n_channels, T, shift_left_idxs, shift_right_idxs, dtype):
    # One T x T identity matrix per channel: by default no channel is shifted.
    mask = torch.stack([torch.eye(T, dtype=dtype) for _ in range(n_channels)])
    # Plant a shift mask at the indices of the channels that move in time.
    mask[shift_left_idxs] = shift_left_mask(T, dtype)
    mask[shift_right_idxs] = shift_right_mask(T, dtype)
    return mask
where a stack of identity matrices is first created and, at the indices of the channels that should shift left or right, the corresponding shift mask is planted in its place; all remaining channels keep their identity matrix and therefore stay static.
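Assuming a data layout of [channels, rows, T] (the sizes and the choice of which channels shift are only an example), the stacked mask is then applied with a single batched matrix multiplication:

import torch

C, T = 4, 40                                    # example channel count and number of timesteps
x = torch.randn(C, 256, T)
# Channel 0 shifts backward, channel 1 shifts forward, channels 2 and 3 stay static.
mask = shift_mask(C, T, shift_left_idxs=[0], shift_right_idxs=[1], dtype=x.dtype)
shifted = torch.matmul(x, mask)                 # (C, 256, T) @ (C, T, T) -> (C, 256, T)
print(shifted.shape)                            # torch.Size([4, 256, 40])
print(torch.allclose(shifted[2], x[2]))         # True: static channels are untouched
print(torch.all(shifted[1][:, 0] == 0).item())  # True: forward-shifted channel is zero-padded at t = 0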