From 5ff2bf5fa6ee40574d608c6479d52c3df9e51ed9 Mon Sep 17 00:00:00 2001 From: NevermindNilas Date: Fri, 14 Jun 2024 21:50:17 +0300 Subject: [PATCH] add pervfi interpolation --- README.md | 1 + main.py | 2 + requirements-linux.txt | 1 + requirements-windows.txt | 1 + src/downloadModels.py | 12 + src/initializeModels.py | 16 + src/pervfiarches/flow_estimators/__init__.py | 166 ++++ src/pervfiarches/flow_estimators/gma/corr.py | 106 +++ .../flow_estimators/gma/extractor.py | 189 +++++ src/pervfiarches/flow_estimators/gma/gma.py | 123 +++ .../flow_estimators/gma/network.py | 143 ++++ .../flow_estimators/gma/update.py | 157 ++++ .../flow_estimators/gma/utils/__init__.py | 0 .../flow_estimators/gma/utils/augmentor.py | 246 ++++++ .../flow_estimators/gma/utils/flow_viz.py | 132 +++ .../flow_estimators/gma/utils/frame_utils.py | 137 +++ .../flow_estimators/gma/utils/utils.py | 191 +++++ .../flow_estimators/gmflow/__init__.py | 0 .../flow_estimators/gmflow/backbone.py | 117 +++ .../flow_estimators/gmflow/geometry.py | 96 +++ .../flow_estimators/gmflow/gmflow.py | 170 ++++ .../flow_estimators/gmflow/matching.py | 83 ++ .../flow_estimators/gmflow/position.py | 46 + .../flow_estimators/gmflow/transformer.py | 409 +++++++++ .../flow_estimators/gmflow/trident_conv.py | 90 ++ .../flow_estimators/gmflow/utils.py | 101 +++ .../flow_estimators/raft/__init__.py | 0 src/pervfiarches/flow_estimators/raft/corr.py | 91 ++ .../flow_estimators/raft/extractor.py | 267 ++++++ src/pervfiarches/flow_estimators/raft/raft.py | 161 ++++ .../flow_estimators/raft/update.py | 139 +++ .../flow_estimators/raft/utils/__init__.py | 0 .../flow_estimators/raft/utils/augmentor.py | 246 ++++++ .../flow_estimators/raft/utils/flow_viz.py | 132 +++ .../flow_estimators/raft/utils/frame_utils.py | 137 +++ .../flow_estimators/raft/utils/utils.py | 82 ++ src/pervfiarches/generators/PFlowVFI_V0.py | 367 ++++++++ src/pervfiarches/generators/PFlowVFI_V2.py | 383 +++++++++ src/pervfiarches/generators/PFlowVFI_Vb.py | 316 +++++++ .../generators/PFlowVFI_ablation.py | 331 ++++++++ .../generators/PFlowVFI_adaptive.py | 198 +++++ src/pervfiarches/generators/__init__.py | 40 + src/pervfiarches/generators/msfusion.py | 515 ++++++++++++ .../generators/normalizing_flow.py | 461 ++++++++++ .../generators/softsplatnet/__init__.py | 794 ++++++++++++++++++ .../generators/softsplatnet/correlation.py | 462 ++++++++++ .../generators/softsplatnet/softsplat.py | 608 ++++++++++++++ src/pervfiarches/generators/thops.py | 68 ++ src/pervfiarches/pipeline.py | 88 ++ src/unifiedInterpolate.py | 156 +++- 50 files changed, 8776 insertions(+), 1 deletion(-) create mode 100644 src/pervfiarches/flow_estimators/__init__.py create mode 100644 src/pervfiarches/flow_estimators/gma/corr.py create mode 100644 src/pervfiarches/flow_estimators/gma/extractor.py create mode 100644 src/pervfiarches/flow_estimators/gma/gma.py create mode 100644 src/pervfiarches/flow_estimators/gma/network.py create mode 100644 src/pervfiarches/flow_estimators/gma/update.py create mode 100644 src/pervfiarches/flow_estimators/gma/utils/__init__.py create mode 100644 src/pervfiarches/flow_estimators/gma/utils/augmentor.py create mode 100644 src/pervfiarches/flow_estimators/gma/utils/flow_viz.py create mode 100644 src/pervfiarches/flow_estimators/gma/utils/frame_utils.py create mode 100644 src/pervfiarches/flow_estimators/gma/utils/utils.py create mode 100644 src/pervfiarches/flow_estimators/gmflow/__init__.py create mode 100644 src/pervfiarches/flow_estimators/gmflow/backbone.py create mode 100644 
src/pervfiarches/flow_estimators/gmflow/geometry.py create mode 100644 src/pervfiarches/flow_estimators/gmflow/gmflow.py create mode 100644 src/pervfiarches/flow_estimators/gmflow/matching.py create mode 100644 src/pervfiarches/flow_estimators/gmflow/position.py create mode 100644 src/pervfiarches/flow_estimators/gmflow/transformer.py create mode 100644 src/pervfiarches/flow_estimators/gmflow/trident_conv.py create mode 100644 src/pervfiarches/flow_estimators/gmflow/utils.py create mode 100644 src/pervfiarches/flow_estimators/raft/__init__.py create mode 100644 src/pervfiarches/flow_estimators/raft/corr.py create mode 100644 src/pervfiarches/flow_estimators/raft/extractor.py create mode 100644 src/pervfiarches/flow_estimators/raft/raft.py create mode 100644 src/pervfiarches/flow_estimators/raft/update.py create mode 100644 src/pervfiarches/flow_estimators/raft/utils/__init__.py create mode 100644 src/pervfiarches/flow_estimators/raft/utils/augmentor.py create mode 100644 src/pervfiarches/flow_estimators/raft/utils/flow_viz.py create mode 100644 src/pervfiarches/flow_estimators/raft/utils/frame_utils.py create mode 100644 src/pervfiarches/flow_estimators/raft/utils/utils.py create mode 100644 src/pervfiarches/generators/PFlowVFI_V0.py create mode 100644 src/pervfiarches/generators/PFlowVFI_V2.py create mode 100644 src/pervfiarches/generators/PFlowVFI_Vb.py create mode 100644 src/pervfiarches/generators/PFlowVFI_ablation.py create mode 100644 src/pervfiarches/generators/PFlowVFI_adaptive.py create mode 100644 src/pervfiarches/generators/__init__.py create mode 100644 src/pervfiarches/generators/msfusion.py create mode 100644 src/pervfiarches/generators/normalizing_flow.py create mode 100644 src/pervfiarches/generators/softsplatnet/__init__.py create mode 100644 src/pervfiarches/generators/softsplatnet/correlation.py create mode 100644 src/pervfiarches/generators/softsplatnet/softsplat.py create mode 100644 src/pervfiarches/generators/thops.py create mode 100644 src/pervfiarches/pipeline.py diff --git a/README.md b/README.md index 6e4643ad..143528f6 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,7 @@ Both internal and user generated benchmarks can be found [here](BENCHMARKS.MD). 
| [cszn](https://github.com/cszn/DPIR) | DPIR | | [TNTWise](https://github.com/TNTwise) | For Rife ONNX and NCNN models | | [WolframRhodium](https://github.com/WolframRhodium) | For Rife V2 models | +| [mulns](https://github.com/mulns/PerVFI) | For PerVFI | ## 🌟 Star History diff --git a/main.py b/main.py index dcaa2e56..a4faef02 100644 --- a/main.py +++ b/main.py @@ -273,6 +273,8 @@ def start(self): "rife4.17-tensorrt", "rife-tensorrt", "gmfss", + "raft_pervfi_lite", + "raft_pervfi", ], default="rife", ) diff --git a/requirements-linux.txt b/requirements-linux.txt index 8ca54371..0906e749 100644 --- a/requirements-linux.txt +++ b/requirements-linux.txt @@ -12,6 +12,7 @@ GPUtil spandrel yt-dlp requests +accelerate tqdm tensorrt scikit-image diff --git a/requirements-windows.txt b/requirements-windows.txt index cd41fd50..07647bdf 100644 --- a/requirements-windows.txt +++ b/requirements-windows.txt @@ -9,6 +9,7 @@ kornia scipy wmi spandrel +accelerate yt-dlp requests tqdm diff --git a/src/downloadModels.py b/src/downloadModels.py index 2ce9b360..cec66d87 100644 --- a/src/downloadModels.py +++ b/src/downloadModels.py @@ -67,6 +67,8 @@ def modelsList() -> list[str]: "small-directml", "base-directml", "large-directml", + "pervfi_lite", + "pervfi", ] @@ -327,6 +329,16 @@ def modelsMap( return "maxxvitv2_rmlp_base_rw_224.sw_in12k_b80_224px_20k_coloraug0.4_6ch_clamp_softmax_fp16_op17_onnxslim.onnx" else: return "maxxvitv2_rmlp_base_rw_224.sw_in12k_b80_224px_20k_coloraug0.4_6ch_clamp_softmax_op17_onnxslim.onnx" + + case "pervfi_lite": + return "pervfi_lite.pth" + + case "pervfi": + return "pervfi.pth" + + case "raft": + return "raft-sintel.pth" + case _: raise ValueError(f"Model {model} not found.") diff --git a/src/initializeModels.py b/src/initializeModels.py index 93105cfe..dc1f0cd3 100644 --- a/src/initializeModels.py +++ b/src/initializeModels.py @@ -287,6 +287,22 @@ def initializeModels(self): self.scenechange, ) + case ( + "raft_pervfi_lite" + | "raft_pervfi" + ): + + from src.unifiedInterpolate import PerVFIRaftCuda + + interpolate_process = PerVFIRaftCuda( + self.interpolate_method, + self.half, + outputWidth, + outputHeight, + self.interpolate_factor, + self.scenechange, + ) + if self.denoise: match self.denoise_method: case "scunet" | "dpir" | "nafnet": diff --git a/src/pervfiarches/flow_estimators/__init__.py b/src/pervfiarches/flow_estimators/__init__.py new file mode 100644 index 00000000..728278fc --- /dev/null +++ b/src/pervfiarches/flow_estimators/__init__.py @@ -0,0 +1,166 @@ +import torch + + +class InputPadder: + """Pads images such that dimensions are divisible by factor""" + + def __init__(self, size, divide=8, mode="center"): + self.ht, self.wd = size[-2:] + pad_ht = (((self.ht // divide) + 1) * divide - self.ht) % divide + pad_wd = (((self.wd // divide) + 1) * divide - self.wd) % divide + if mode == "center": + self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] + else: + self._pad = [0, pad_wd, 0, pad_ht] + + def _pad_(self, x): + return torch.nn.functional.pad(x, self._pad, mode="constant") + + def pad(self, *inputs): + return [self._pad_(x) for x in inputs] + + def _unpad_(self, x): + return x[ + ..., + self._pad[2] : self.ht + self._pad[2], + self._pad[0] : self.wd + self._pad[0], + ] + + def unpad(self, *inputs): + return [self._unpad_(x) for x in inputs] + + +def build_flow_estimator(name, device="cuda", checkpoint=None): + if name.lower() == "raft": + import argparse + + from .raft.raft import RAFT + + args = 
argparse.Namespace( + mixed_precision=True, alternate_corr=False, small=False + ) + model = RAFT(args) + + if checkpoint is None: + ckpt = "checkpoints/RAFT/raft-sintel.pth" + else: + ckpt = checkpoint # Path to the checkpoint .pth, modified from original arch for better comp + + model.load_state_dict( + {k.replace("module.", ""): v for k, v in torch.load(ckpt, map_location=device).items()} + ) + model.to(device).eval() + + @torch.no_grad() + def infer(I1, I2): + I1 = I1.to(device) * 255.0 + I2 = I2.to(device) * 255.0 + padder = InputPadder(I1.shape, 8) + I1, I2 = padder.pad(I1, I2) + fflow = model(I1, I2, bidirection=False, iters=12) + bflow = model(I2, I1, bidirection=False, iters=12) + return padder.unpad(fflow, bflow) + + if name.lower() == "raft_small": + import argparse + + from .raft.raft import RAFT + + args = argparse.Namespace( + mixed_precision=True, alternate_corr=False, small=True + ) + model = RAFT(args) + ckpt = "checkpoints/RAFT/raft-small.pth" + model.load_state_dict( + {k.replace("module.", ""): v for k, v in torch.load(ckpt).items()} + ) + model.to(device).eval() + + @torch.no_grad() + def infer(I1, I2): + I1 = I1.to(device) * 255.0 + I2 = I2.to(device) * 255.0 + padder = InputPadder(I1.shape, 8) + I1, I2 = padder.pad(I1, I2) + fflow = model(I1, I2, bidirection=False, iters=12) + bflow = model(I2, I1, bidirection=False, iters=12) + return padder.unpad(fflow, bflow) + + if name.lower() == "gma": + import argparse + + from .gma.network import RAFTGMA + + args = argparse.Namespace( + mixed_precision=True, + num_heads=1, + position_only=False, + position_and_content=False, + ) + model = RAFTGMA(args) + ckpt = "checkpoints/GMA/gma-sintel.pth" + model.load_state_dict( + {k.replace("module.", ""): v for k, v in torch.load(ckpt).items()} + ) + model.to(device).eval() + + @torch.no_grad() + def infer(I1, I2): + I1 = I1.to(device) * 255.0 + I2 = I2.to(device) * 255.0 + padder = InputPadder(I1.shape, 8) + I1, I2 = padder.pad(I1, I2) + _, fflow = model(I1, I2, test_mode=True, iters=20) + _, bflow = model(I2, I1, test_mode=True, iters=20) + return padder.unpad(fflow, bflow) + + if name.lower() == "gmflow": + from .gmflow.gmflow import GMFlow + + model = GMFlow( + feature_channels=128, + num_scales=1, + upsample_factor=8, + num_head=1, + attention_type="swin", + ffn_dim_expansion=4, + num_transformer_layers=6, + ) + ckpt = "checkpoints/GMFlow/gmflow_sintel-0c07dcb3.pth" + model.load_state_dict(torch.load(ckpt)["model"]) + model.to(device).eval() + + @torch.no_grad() + def infer(I1, I2): + I1 = I1.to(device) * 255.0 + I2 = I2.to(device) * 255.0 + padder = InputPadder(I1.shape, 16) + I1, I2 = padder.pad(I1, I2) + results_dict = model( + I1, + I2, + attn_splits_list=[2], + corr_radius_list=[-1], + prop_radius_list=[-1], + pred_bidir_flow=False, + ) + fflow = results_dict["flow_preds"][-1] + results_dict = model( + I2, + I1, + attn_splits_list=[2], + corr_radius_list=[-1], + prop_radius_list=[-1], + pred_bidir_flow=False, + ) + bflow = results_dict["flow_preds"][-1] + + fflow, bflow = padder.unpad(fflow, bflow) + return fflow, bflow + + return model, infer diff --git a/src/pervfiarches/flow_estimators/gma/corr.py b/src/pervfiarches/flow_estimators/gma/corr.py new file mode 100644 index 00000000..7d22fe09 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gma/corr.py @@ -0,0 +1,106 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils.utils import bilinear_sampler, coords_grid + +# from compute_sparse_correlation import 
compute_sparse_corr, compute_sparse_corr_torch, compute_sparse_corr_mink + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class CorrBlockSingleScale(nn.Module): + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + super().__init__() + self.radius = radius + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + batch, h1, w1, dim, h2, w2 = corr.shape + self.corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + corr = self.corr + dx = torch.linspace(-r, r, 2 * r + 1) + dy = torch.linspace(-r, r, 2 * r + 1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + out = corr.view(batch, h1, w1, -1) + out = out.permute(0, 3, 1, 2).contiguous().float() + return out + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/src/pervfiarches/flow_estimators/gma/extractor.py b/src/pervfiarches/flow_estimators/gma/extractor.py new file mode 100644 index 00000000..54f17833 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gma/extractor.py @@ -0,0 +1,189 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + 
num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + 
+ self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/src/pervfiarches/flow_estimators/gma/gma.py b/src/pervfiarches/flow_estimators/gma/gma.py new file mode 100644 index 00000000..c1c84492 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gma/gma.py @@ -0,0 +1,123 @@ +import torch +from torch import nn, einsum +from einops import rearrange + + +class RelPosEmb(nn.Module): + def __init__( + self, + max_pos_size, + dim_head + ): + super().__init__() + self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) + self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) + + deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) + rel_ind = deltas + max_pos_size - 1 + self.register_buffer('rel_ind', rel_ind) + + def forward(self, q): + batch, heads, h, w, c = q.shape + height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) + width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) + + height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) + width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) + + height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) + width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) + + return height_score + width_score + + +class Attention(nn.Module): + def __init__( + self, + *, + args, + dim, + max_pos_size = 100, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) + + self.pos_emb = RelPosEmb(max_pos_size, dim_head) + + def forward(self, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + q, k = self.to_qk(fmap).chunk(2, dim=1) + + q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) + q = self.scale * q + + if self.args.position_only: + sim = self.pos_emb(q) + + elif self.args.position_and_content: + sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + sim_pos = self.pos_emb(q) + sim = sim_content + sim_pos + + else: + sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + + sim = rearrange(sim, 
'b h x y u v -> b h (x y) (u v)') + attn = sim.softmax(dim=-1) + + return attn + + +class Aggregate(nn.Module): + def __init__( + self, + args, + dim, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) + + self.gamma = nn.Parameter(torch.zeros(1)) + + if dim != inner_dim: + self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) + else: + self.project = None + + def forward(self, attn, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + v = self.to_v(fmap) + v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) + + if self.project is not None: + out = self.project(out) + + out = fmap + self.gamma * out + + return out + + +if __name__ == "__main__": + att = Attention(dim=128, heads=1) + fmap = torch.randn(2, 128, 40, 90) + out = att(fmap) + + print(out.shape) diff --git a/src/pervfiarches/flow_estimators/gma/network.py b/src/pervfiarches/flow_estimators/gma/network.py new file mode 100644 index 00000000..4dcf5ebb --- /dev/null +++ b/src/pervfiarches/flow_estimators/gma/network.py @@ -0,0 +1,143 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .corr import CorrBlock +from .extractor import BasicEncoder +from .gma import Aggregate, Attention +from .update import GMAUpdateBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + +class RAFTGMA(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + + if "dropout" not in self.args: + self.args.dropout = 0 + + # feature network, context network, and update block + self.fnet = BasicEncoder( + output_dim=256, norm_fn="instance", dropout=args.dropout + ) + self.cnet = BasicEncoder( + output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout + ) + self.update_block = GMAUpdateBlock(self.args, hidden_dim=hdim) + self.att = Attention( + args=self.args, + dim=cdim, + heads=self.args.num_heads, + max_pos_size=160, + dim_head=cdim, + ) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H // 8, W // 8).to(img.device) + coords1 = coords_grid(N, H // 8, W // 8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def forward( + self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False + ): + """Estimate optical flow 
between pair of frames""" + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + # attention, att_c, att_p = self.att(inp) + attention = self.att(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block( + net, inp, corr, flow, attention + ) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - coords0, flow_up + + return flow_predictions diff --git a/src/pervfiarches/flow_estimators/gma/update.py b/src/pervfiarches/flow_estimators/gma/update.py new file mode 100644 index 00000000..2fa73b94 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gma/update.py @@ -0,0 +1,157 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .gma import Aggregate + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=128 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convr1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convq1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + + self.convz2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convr2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convq2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = 
torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow + + +class GMAUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128): + super().__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU( + hidden_dim=hidden_dim, input_dim=128 + hidden_dim + hidden_dim + ) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) + + self.aggregator = Aggregate( + args=self.args, dim=128, dim_head=128, heads=self.args.num_heads + ) + + def forward(self, net, inp, corr, flow, attention): + motion_features = self.encoder(flow, corr) + motion_features_global = self.aggregator(attention, motion_features) + inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) + + # Attentional update + net = self.gru(net, inp_cat) + + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = 0.25 * self.mask(net) + return net, mask, delta_flow diff --git a/src/pervfiarches/flow_estimators/gma/utils/__init__.py b/src/pervfiarches/flow_estimators/gma/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pervfiarches/flow_estimators/gma/utils/augmentor.py b/src/pervfiarches/flow_estimators/gma/utils/augmentor.py new file mode 100644 index 00000000..f73bab60 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gma/utils/augmentor.py @@ -0,0 +1,246 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional 
as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + 
return img1, img2, flow + + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], 
x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/src/pervfiarches/flow_estimators/gma/utils/flow_viz.py b/src/pervfiarches/flow_estimators/gma/utils/flow_viz.py new file mode 100644 index 00000000..dcee65e8 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gma/utils/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/src/pervfiarches/flow_estimators/gma/utils/frame_utils.py b/src/pervfiarches/flow_estimators/gma/utils/frame_utils.py new file mode 100644 index 00000000..6c491135 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gma/utils/frame_utils.py @@ -0,0 +1,137 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. 
Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def writeFlow(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] \ No newline at end of file diff --git a/src/pervfiarches/flow_estimators/gma/utils/utils.py b/src/pervfiarches/flow_estimators/gma/utils/utils.py new file mode 100644 index 00000000..f64841b8 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gma/utils/utils.py @@ -0,0 +1,191 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate +# from torch_scatter import scatter_softmax, scatter_add + + +class InputPadder: + """ Pads images such that dimensions 
are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].expand(batch, -1, -1, -1) + + +def coords_grid_y_first(batch, ht, wd): + """Place y grid before x grid""" + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords, dim=0).int() + return coords[None].expand(batch, -1, -1, -1) + + +def soft_argmax(corr_me, B, H1, W1): + # Implement soft argmin + coords, feats = corr_me.decomposed_coordinates_and_features + + # Computing soft argmin + flow_pred = torch.zeros(B, 2, H1, W1).to(corr_me.device) + for batch, (coord, feat) in enumerate(zip(coords, feats)): + coord_img_1 = coord[:, :2].to(corr_me.device) + coord_img_2 = coord[:, 2:].to(corr_me.device) + # relative positions (flow hypotheses) + rel_pos = (coord_img_2 - coord_img_1) + # augmented indices + aug_coord_img_1 = (coord_img_1[:, 0:1] * W1 + coord_img_1[:, 1:2]).long() + # run softmax on the score + weight = scatter_softmax(feat, aug_coord_img_1, dim=0) + rel_pos_weighted = weight * rel_pos + out = scatter_add(rel_pos_weighted, aug_coord_img_1, dim=0) + # Need to permute (y, x) to (x, y) for flow + flow_pred[batch] = out[:, [1,0]].view(H1, W1, 2).permute(2, 0, 1) + return flow_pred + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + +def upflow4(flow, mode='bilinear'): + new_size = (4 * flow.shape[2], 4 * flow.shape[3]) + return 4 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + +def upflow2(flow, mode='bilinear'): + new_size = (2 * flow.shape[2], 2 * flow.shape[3]) + return 2 * F.interpolate(flow, 
size=new_size, mode=mode, align_corners=True) + + +def downflow8(flow, mode='bilinear'): + new_size = (flow.shape[2] // 8, flow.shape[3] // 8) + return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) / 8 + + +def downflow4(flow, mode='bilinear'): + new_size = (flow.shape[2] // 4, flow.shape[3] // 4) + return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) / 4 + + +def compute_interpolation_weights(yx_warped): + # yx_warped: [N, 2] + y_warped = yx_warped[:, 0] + x_warped = yx_warped[:, 1] + + # elementwise operations below + y_f = torch.floor(y_warped) + y_c = y_f + 1 + x_f = torch.floor(x_warped) + x_c = x_f + 1 + + w0 = (y_c - y_warped) * (x_c - x_warped) + w1 = (y_warped - y_f) * (x_c - x_warped) + w2 = (y_c - y_warped) * (x_warped - x_f) + w3 = (y_warped - y_f) * (x_warped - x_f) + + weights = [w0, w1, w2, w3] + indices = [torch.stack([y_f, x_f], dim=1), torch.stack([y_c, x_f], dim=1), + torch.stack([y_f, x_c], dim=1), torch.stack([y_c, x_c], dim=1)] + weights = torch.cat(weights, dim=1) + indices = torch.cat(indices, dim=2) + # indices = torch.cat(indices, dim=0) # [4*N, 2] + + return weights, indices + +# weights, indices = compute_interpolation_weights(xy_warped, b, h_i, w_i) + + +def compute_inverse_interpolation_img(weights, indices, img, b, h_i, w_i): + """ + weights: [b, h*w] + indices: [b, h*w] + img: [b, h*w, a, b, c, ...] + """ + w0, w1, w2, w3 = weights + ff_idx, cf_idx, fc_idx, cc_idx = indices + + k = len(img.size()) - len(w0.size()) + img_0 = w0[(...,) + (None,) * k] * img + img_1 = w1[(...,) + (None,) * k] * img + img_2 = w2[(...,) + (None,) * k] * img + img_3 = w3[(...,) + (None,) * k] * img + + img_out = torch.zeros(b, h_i * w_i, *img.shape[2:]).type_as(img) + + ff_idx = torch.clamp(ff_idx, min=0, max=h_i * w_i - 1) + cf_idx = torch.clamp(cf_idx, min=0, max=h_i * w_i - 1) + fc_idx = torch.clamp(fc_idx, min=0, max=h_i * w_i - 1) + cc_idx = torch.clamp(cc_idx, min=0, max=h_i * w_i - 1) + + img_out.scatter_add_(1, ff_idx[(...,) + (None,) * k].expand_as(img_0), img_0) + img_out.scatter_add_(1, cf_idx[(...,) + (None,) * k].expand_as(img_1), img_1) + img_out.scatter_add_(1, fc_idx[(...,) + (None,) * k].expand_as(img_2), img_2) + img_out.scatter_add_(1, cc_idx[(...,) + (None,) * k].expand_as(img_3), img_3) + + return img_out # [b, h_i*w_i, ...] 
diff --git a/src/pervfiarches/flow_estimators/gmflow/__init__.py b/src/pervfiarches/flow_estimators/gmflow/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pervfiarches/flow_estimators/gmflow/backbone.py b/src/pervfiarches/flow_estimators/gmflow/backbone.py new file mode 100644 index 00000000..a30942ec --- /dev/null +++ b/src/pervfiarches/flow_estimators/gmflow/backbone.py @@ -0,0 +1,117 @@ +import torch.nn as nn + +from .trident_conv import MultiScaleTridentConv + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, + ): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, stride=stride, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, bias=False) + self.relu = nn.ReLU(inplace=True) + + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) + if not stride == 1 or in_planes != planes: + self.norm3 = norm_layer(planes) + + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class CNNEncoder(nn.Module): + def __init__(self, output_dim=128, + norm_layer=nn.InstanceNorm2d, + num_output_scales=1, + **kwargs, + ): + super(CNNEncoder, self).__init__() + self.num_branch = num_output_scales + + feature_dims = [64, 96, 128] + + self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 + self.norm1 = norm_layer(feature_dims[0]) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = feature_dims[0] + self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 + self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer(feature_dims[2], stride=stride, + norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) + layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = 
self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out diff --git a/src/pervfiarches/flow_estimators/gmflow/geometry.py b/src/pervfiarches/flow_estimators/gmflow/geometry.py new file mode 100644 index 00000000..207e98fd --- /dev/null +++ b/src/pervfiarches/flow_estimators/gmflow/geometry.py @@ -0,0 +1,96 @@ +import torch +import torch.nn.functional as F + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5 + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ diff --git a/src/pervfiarches/flow_estimators/gmflow/gmflow.py b/src/pervfiarches/flow_estimators/gmflow/gmflow.py new file mode 100644 index 00000000..cd413833 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gmflow/gmflow.py @@ 
-0,0 +1,170 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import CNNEncoder +from .transformer import FeatureTransformer, FeatureFlowAttention +from .matching import global_correlation_softmax, local_correlation_softmax +from .geometry import flow_warp +from .utils import normalize_img, feature_add_position + + +class GMFlow(nn.Module): + def __init__(self, + num_scales=1, + upsample_factor=8, + feature_channels=128, + attention_type='swin', + num_transformer_layers=6, + ffn_dim_expansion=4, + num_head=1, + **kwargs, + ): + super(GMFlow, self).__init__() + + self.num_scales = num_scales + self.feature_channels = feature_channels + self.upsample_factor = upsample_factor + self.attention_type = attention_type + self.num_transformer_layers = num_transformer_layers + + # CNN backbone + self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales) + + # Transformer + self.transformer = FeatureTransformer(num_layers=num_transformer_layers, + d_model=feature_channels, + nhead=num_head, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + ) + + # flow propagation with self-attn + self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels) + + # convex upsampling: concat feature0 and flow as input + self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0)) + + def extract_feature(self, img0, img1): + concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] + features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low + + # reverse: resolution from low to high + features = features[::-1] + + feature0, feature1 = [], [] + + for i in range(len(features)): + feature = features[i] + chunks = torch.chunk(feature, 2, 0) # tuple + feature0.append(chunks[0]) + feature1.append(chunks[1]) + + return feature0, feature1 + + def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, + ): + if bilinear: + up_flow = F.interpolate(flow, scale_factor=upsample_factor, + mode='bilinear', align_corners=True) * upsample_factor + + else: + # convex upsampling + concat = torch.cat((flow, feature), dim=1) + + mask = self.upsampler(concat) + b, flow_channel, h, w = flow.shape + mask = mask.view(b, 1, 9, self.upsample_factor, self.upsample_factor, h, w) # [B, 1, 9, K, K, H, W] + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1) + up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] + + up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] + up_flow = up_flow.reshape(b, flow_channel, self.upsample_factor * h, + self.upsample_factor * w) # [B, 2, K*H, K*W] + + return up_flow + + def forward(self, img0, img1, + attn_splits_list=None, + corr_radius_list=None, + prop_radius_list=None, + pred_bidir_flow=False, + **kwargs, + ): + + results_dict = {} + flow_preds = [] + + img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] + + # resolution low to high + feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features + + flow = None + + assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales + + for scale_idx in range(self.num_scales): + feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] + + if pred_bidir_flow and scale_idx > 0: + # predicting bidirectional 
flow with refinement + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + + upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx)) + + if scale_idx > 0: + flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2 + + if flow is not None: + flow = flow.detach() + feature1 = flow_warp(feature1, flow) # [B, C, H, W] + + attn_splits = attn_splits_list[scale_idx] + corr_radius = corr_radius_list[scale_idx] + prop_radius = prop_radius_list[scale_idx] + + # add position to features + feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels) + + # Transformer + feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits) + + # correlation and softmax + if corr_radius == -1: # global matching + flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0] + else: # local matching + flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0] + + # flow or residual flow + flow = flow + flow_pred if flow is not None else flow_pred + + # upsample to the original resolution for supervison + if self.training: # only need to upsample intermediate flow predictions at training time + flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor) + flow_preds.append(flow_bilinear) + + # flow propagation with self-attn + if pred_bidir_flow and scale_idx == 0: + feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation + flow = self.feature_flow_attn(feature0, flow.detach(), + local_window_attn=prop_radius > 0, + local_window_radius=prop_radius) + + # bilinear upsampling at training time except the last one + if self.training and scale_idx < self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0, bilinear=True, upsample_factor=upsample_factor) + flow_preds.append(flow_up) + + if scale_idx == self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0) + flow_preds.append(flow_up) + + results_dict.update({'flow_preds': flow_preds}) + + return results_dict diff --git a/src/pervfiarches/flow_estimators/gmflow/matching.py b/src/pervfiarches/flow_estimators/gmflow/matching.py new file mode 100644 index 00000000..17402009 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gmflow/matching.py @@ -0,0 +1,83 @@ +import torch +import torch.nn.functional as F + +from .geometry import coords_grid, generate_window_grid, normalize_coords + + +def global_correlation_softmax(feature0, feature1, + pred_bidir_flow=False, + ): + # global correlation + b, c, h, w = feature0.shape + feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.view(b, c, -1) # [B, C, H*W] + + correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] + + # flow from softmax + init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] + + if pred_bidir_flow: + correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] + init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] + grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] + b = b * 2 + + prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] + + correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + # when predicting bidirectional flow, flow is 
the concatenation of forward flow and backward flow + flow = correspondence - init_grid + + return flow, prob + + +def local_correlation_softmax(feature0, feature1, local_radius, + padding_mode='zeros', + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] + + correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + flow = correspondence - coords_init + match_prob = prob + + return flow, match_prob diff --git a/src/pervfiarches/flow_estimators/gmflow/position.py b/src/pervfiarches/flow_estimators/gmflow/position.py new file mode 100644 index 00000000..14a6da43 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gmflow/position.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py + +import torch +import torch.nn as nn +import math + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/src/pervfiarches/flow_estimators/gmflow/transformer.py b/src/pervfiarches/flow_estimators/gmflow/transformer.py new file mode 100644 index 00000000..9a8f2ceb --- /dev/null +++ b/src/pervfiarches/flow_estimators/gmflow/transformer.py @@ -0,0 +1,409 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import split_feature, merge_splits + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, + shift_size_h, shift_size_w, device=torch.device('cuda')): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 + h_slices = (slice(0, -window_size_h), + slice(-window_size_h, -shift_size_h), + slice(-shift_size_h, None)) + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +def single_head_split_window_attention(q, k, v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + 
window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] + + if with_shift: + scores += attn_mask.repeat(b, 1, 1) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + + out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, channel_last=True) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + + return out + + +class TransformerLayer(nn.Module): + def __init__(self, + d_model=256, + nhead=1, + attention_type='swin', + no_ffn=False, + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerLayer, self).__init__() + + self.dim = d_model + self.nhead = nhead + self.attention_type = attention_type + self.no_ffn = no_ffn + + self.with_shift = with_shift + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.merge = nn.Linear(d_model, d_model, bias=False) + + self.norm1 = nn.LayerNorm(d_model) + + # no ffn after self-attn, with ffn after cross-attn + if not self.no_ffn: + in_channels = d_model * 2 + self.mlp = nn.Sequential( + nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), + nn.GELU(), + nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), + ) + + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + attn_num_splits=None, + **kwargs, + ): + # source, target: [B, L, C] + query, key, value = source, target, target + + # single-head attention + query = self.q_proj(query) # [B, L, C] + key = self.k_proj(key) # [B, L, C] + value = self.v_proj(value) # [B, L, C] + + if self.attention_type == 'swin' and attn_num_splits > 1: + if self.nhead > 1: + # we observe that multihead attention slows down the speed and increases the memory consumption + # without bringing obvious performance gains and thus the implementation is removed + raise NotImplementedError + else: + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=self.with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + message = single_head_full_attention(query, key, value) # [B, L, C] + + message = self.merge(message) # [B, L, C] + message = self.norm1(message) + + if not self.no_ffn: + message = self.mlp(torch.cat([source, message], dim=-1)) + message = self.norm2(message) + + return source + message + + +class TransformerBlock(nn.Module): + """self 
attention + cross attention + FFN""" + + def __init__(self, + d_model=256, + nhead=1, + attention_type='swin', + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerBlock, self).__init__() + + self.self_attn = TransformerLayer(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + no_ffn=True, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=with_shift, + ) + + self.cross_attn_ffn = TransformerLayer(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=with_shift, + ) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + attn_num_splits=None, + **kwargs, + ): + # source, target: [B, L, C] + + # self attention + source = self.self_attn(source, source, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # cross attention and ffn + source = self.cross_attn_ffn(source, target, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + return source + + +class FeatureTransformer(nn.Module): + def __init__(self, + num_layers=6, + d_model=128, + nhead=1, + attention_type='swin', + ffn_dim_expansion=4, + **kwargs, + ): + super(FeatureTransformer, self).__init__() + + self.attention_type = attention_type + + self.d_model = d_model + self.nhead = nhead + + self.layers = nn.ModuleList([ + TransformerBlock(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=True if attention_type == 'swin' and i % 2 == 1 else False, + ) + for i in range(num_layers)]) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, feature1, + attn_num_splits=None, + **kwargs, + ): + + b, c, h, w = feature0.shape + assert self.d_model == c + + feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + + if self.attention_type == 'swin' and attn_num_splits > 1: + # global and refine use different number of splits + window_size_h = h // attn_num_splits + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask = generate_shift_window_attn_mask( + input_resolution=(h, w), + window_size_h=window_size_h, + window_size_w=window_size_w, + shift_size_h=window_size_h // 2, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K*K, H/K*W/K, H/K*W/K] + else: + shifted_window_attn_mask = None + + # concat feature0 and feature1 in batch dimension to compute in parallel + concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] + concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] + + for layer in self.layers: + concat0 = layer(concat0, concat1, + height=h, + width=w, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # update feature1 + concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) + + feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] + + # reshape back + feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + + return feature0, feature1 + + +class FeatureFlowAttention(nn.Module): + """ + flow propagation with self-attention on feature + query: feature0, key: feature0, value: flow + """ + + def 
__init__(self, in_channels, + **kwargs, + ): + super(FeatureFlowAttention, self).__init__() + + self.q_proj = nn.Linear(in_channels, in_channels) + self.k_proj = nn.Linear(in_channels, in_channels) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, flow, + local_window_attn=False, + local_window_radius=1, + **kwargs, + ): + # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] + if local_window_attn: + return self.forward_local_window_attn(feature0, flow, + local_window_radius=local_window_radius) + + b, c, h, w = feature0.size() + + query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] + + # a note: the ``correct'' implementation should be: + # ``query = self.q_proj(query), key = self.k_proj(query)'' + # this problem is observed while cleaning up the code + # however, this doesn't affect the performance since the projection is a linear operation, + # thus the two projection matrices for key can be merged + # so I just leave it as is in order to not re-train all models :) + query = self.q_proj(query) # [B, H*W, C] + key = self.k_proj(query) # [B, H*W, C] + + value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] + + scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W] + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, value) # [B, H*W, 2] + out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] + + return out + + def forward_local_window_attn(self, feature0, flow, + local_window_radius=1, + ): + assert flow.size(1) == 2 + assert local_window_radius > 0 + + b, c, h, w = feature0.size() + + feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1) + ).reshape(b * h * w, 1, c) # [B*H*W, 1, C] + + kernel_size = 2 * local_window_radius + 1 + + feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) + + feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size, + padding=local_window_radius) # [B, C*(2R+1)^2), H*W] + + feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute( + 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2] + + flow_window = F.unfold(flow, kernel_size=kernel_size, + padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] + + flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute( + 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2) # [B*H*W, (2R+1)^2, 2] + + scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2] + + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + return out diff --git a/src/pervfiarches/flow_estimators/gmflow/trident_conv.py b/src/pervfiarches/flow_estimators/gmflow/trident_conv.py new file mode 100644 index 00000000..29a2a73e --- /dev/null +++ b/src/pervfiarches/flow_estimators/gmflow/trident_conv.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _pair + + +class MultiScaleTridentConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + strides=1, + paddings=0, + dilations=1, + dilation=1, + groups=1, + num_branch=1, + test_branch_idx=-1, + bias=False, + norm=None, + activation=None, + ): + super(MultiScaleTridentConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.num_branch = num_branch + self.stride = _pair(stride) + self.groups = groups + self.with_bias = bias + self.dilation = dilation + if isinstance(paddings, int): + paddings = [paddings] * self.num_branch + if isinstance(dilations, int): + dilations = [dilations] * self.num_branch + if isinstance(strides, int): + strides = [strides] * self.num_branch + self.paddings = [_pair(padding) for padding in paddings] + self.dilations = [_pair(dilation) for dilation in dilations] + self.strides = [_pair(stride) for stride in strides] + self.test_branch_idx = test_branch_idx + self.norm = norm + self.activation = activation + + assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, inputs): + num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 + assert len(inputs) == num_branch + + if self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) + for input, stride, padding in zip(inputs, self.strides, self.paddings) + ] + else: + outputs = [ + F.conv2d( + inputs[0], + self.weight, + self.bias, + self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], + self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], + self.dilation, + self.groups, + ) + ] + + if self.norm is not None: + outputs = [self.norm(x) for x in outputs] + if self.activation is not None: + outputs = [self.activation(x) for x in outputs] + return outputs diff --git a/src/pervfiarches/flow_estimators/gmflow/utils.py b/src/pervfiarches/flow_estimators/gmflow/utils.py new file mode 100644 index 00000000..7a659e39 --- /dev/null +++ b/src/pervfiarches/flow_estimators/gmflow/utils.py @@ -0,0 +1,101 @@ +import torch + +from .position import PositionEmbeddingSine + + +def split_feature( + feature, + num_splits=2, + channel_last=False, +): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = ( + feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c) + .permute(0, 1, 3, 2, 4, 5) + .reshape(b_new, h_new, w_new, c) + ) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0, f"height: {h}, width: {w}" + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + 
feature = ( + feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits) + .permute(0, 2, 4, 1, 3, 5) + .reshape(b_new, c, h_new, w_new) + ) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits( + splits, + num_splits=2, + channel_last=False, +): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = ( + splits.permute(0, 1, 3, 2, 4, 5) + .contiguous() + .view(new_b, num_splits * h, num_splits * w, c) + ) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = ( + splits.permute(0, 3, 1, 4, 2, 5) + .contiguous() + .view(new_b, c, num_splits * h, num_splits * w) + ) # [B, C, H, W] + + return merge + + +def normalize_img(img0, img1): + # loaded images are in [0, 255] + # normalize by ImageNet mean and std + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) + img0 = (img0 / 255.0 - mean) / std + img1 = (img1 / 255.0 - mean) / std + + return img0, img1 + + +def feature_add_position(feature0, feature1, attn_splits, feature_channels): + pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + if attn_splits > 1: # add position in splited window + feature0_splits = split_feature(feature0, num_splits=attn_splits) + feature1_splits = split_feature(feature1, num_splits=attn_splits) + + position = pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = pos_enc(feature0) + + feature0 = feature0 + position + feature1 = feature1 + position + + return feature0, feature1 diff --git a/src/pervfiarches/flow_estimators/raft/__init__.py b/src/pervfiarches/flow_estimators/raft/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pervfiarches/flow_estimators/raft/corr.py b/src/pervfiarches/flow_estimators/raft/corr.py new file mode 100644 index 00000000..7cd97cff --- /dev/null +++ b/src/pervfiarches/flow_estimators/raft/corr.py @@ -0,0 +1,91 @@ +import torch +import torch.nn.functional as F + +from .utils.utils import bilinear_sampler, coords_grid + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4, bidirection=False): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + if bidirection: + corr1 = CorrBlock.corr(fmap1, fmap2) + corr2 = corr1.permute(0, 4, 5, 3, 1, 2) + corr = torch.cat([corr1, corr2], dim=0) + else: + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) + + 
centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/src/pervfiarches/flow_estimators/raft/extractor.py b/src/pervfiarches/flow_estimators/raft/extractor.py new file mode 100644 index 00000000..9a9c759d --- /dev/null +++ b/src/pervfiarches/flow_estimators/raft/extractor.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, 
norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = 
self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/src/pervfiarches/flow_estimators/raft/raft.py b/src/pervfiarches/flow_estimators/raft/raft.py new file mode 100644 index 00000000..e0b21c83 --- /dev/null +++ b/src/pervfiarches/flow_estimators/raft/raft.py @@ -0,0 +1,161 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .corr import CorrBlock +from .extractor import BasicEncoder, SmallEncoder +from .update import BasicUpdateBlock, SmallUpdateBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + + if "dropout" not in self.args: + self.args.dropout = 0 + + if "alternate_corr" not in self.args: + self.args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder( + output_dim=128, norm_fn="instance", dropout=args.dropout + ) + self.cnet = 
SmallEncoder( + output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout + ) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder( + output_dim=256, norm_fn="instance", dropout=args.dropout + ) + self.cnet = BasicEncoder( + output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout + ) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H // 8, W // 8, device=img.device) + coords1 = coords_grid(N, H // 8, W // 8, device=img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def forward( + self, + image1, + image2, + iters=12, + flow_init=None, + bidirection=True, + ): + """Estimate optical flow between pair of frames""" + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + raise NotImplementedError + else: + corr_fn = CorrBlock( + fmap1, fmap2, radius=self.args.corr_radius, bidirection=bidirection + ) + if bidirection: + image1 = torch.cat([image1, image2], dim=0) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + if bidirection: + return torch.chunk(flow_up, 2, 0) + else: + return flow_up diff --git a/src/pervfiarches/flow_estimators/raft/update.py b/src/pervfiarches/flow_estimators/raft/update.py new file mode 100644 index 00000000..f940497f --- /dev/null +++ b/src/pervfiarches/flow_estimators/raft/update.py @@ -0,0 +1,139 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = 
nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = 
self.flow_head(net) + + return net, None, delta_flow + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + + + diff --git a/src/pervfiarches/flow_estimators/raft/utils/__init__.py b/src/pervfiarches/flow_estimators/raft/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pervfiarches/flow_estimators/raft/utils/augmentor.py b/src/pervfiarches/flow_estimators/raft/utils/augmentor.py new file mode 100644 index 00000000..e81c4f2b --- /dev/null +++ b/src/pervfiarches/flow_estimators/raft/utils/augmentor.py @@ -0,0 +1,246 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import torch.nn.functional as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, 
self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + return img1, img2, flow + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], 
dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/src/pervfiarches/flow_estimators/raft/utils/flow_viz.py b/src/pervfiarches/flow_estimators/raft/utils/flow_viz.py new file mode 100644 index 00000000..dcee65e8 --- /dev/null +++ b/src/pervfiarches/flow_estimators/raft/utils/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. 
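    For reference, the segment lengths sum to RY + YG + GC + CB + BM + MR
    = 15 + 6 + 4 + 11 + 13 + 6 = 55 entries, so callers can rely on a
    (55, 3) array. A minimal sanity check (illustrative only):

        wheel = make_colorwheel()
        assert wheel.shape == (55, 3)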
+ + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
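    A minimal usage sketch (the random flow below is a stand-in, not data
    shipped with this file):

        fake_flow = np.random.randn(240, 320, 2).astype(np.float32)
        rgb = flow_to_image(fake_flow)  # uint8 image of shape (240, 320, 3)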
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/src/pervfiarches/flow_estimators/raft/utils/frame_utils.py b/src/pervfiarches/flow_estimators/raft/utils/frame_utils.py new file mode 100644 index 00000000..6c491135 --- /dev/null +++ b/src/pervfiarches/flow_estimators/raft/utils/frame_utils.py @@ -0,0 +1,137 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def writeFlow(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. 
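    A small round-trip sketch (the path is a placeholder):

        uv = np.zeros((4, 5, 2), dtype=np.float32)
        writeFlow("/tmp/example.flo", uv)
        assert readFlow("/tmp/example.flo").shape == (4, 5, 2)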
+ """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] \ No newline at end of file diff --git a/src/pervfiarches/flow_estimators/raft/utils/utils.py b/src/pervfiarches/flow_estimators/raft/utils/utils.py new file mode 100644 index 00000000..741ccfe4 --- /dev/null +++ b/src/pervfiarches/flow_estimators/raft/utils/utils.py @@ -0,0 +1,82 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + 
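    # grid_sample expects sampling locations in the normalized range [-1, 1],
    # so the pixel coordinates are rescaled below: x = 0 maps to -1,
    # x = (W - 1) / 2 maps to 0, and x = W - 1 maps to +1 (likewise for y
    # with H); the optional mask flags samples that land inside that range.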
xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/src/pervfiarches/generators/PFlowVFI_V0.py b/src/pervfiarches/generators/PFlowVFI_V0.py new file mode 100644 index 00000000..2863123c --- /dev/null +++ b/src/pervfiarches/generators/PFlowVFI_V0.py @@ -0,0 +1,367 @@ +"""PerVFI: Soft-binary Blending for Photo-realistic Video Frame Interpolation""" + +import accelerate +import torch +import torch.nn.functional as F + +from torch import Tensor +from torchvision.ops import DeformConv2d + +from . import thops +from .msfusion import MultiscaleFuse +from .normalizing_flow import * +from .softsplatnet import Encode, Softmetric +from .softsplatnet.softsplat import softsplat + + +def resize(x, size: tuple, scale: bool): + H, W = x.shape[-2:] + h, w = size + scale_ = h / H + x_ = F.interpolate(x, size, mode="bilinear", align_corners=False) + if scale: + return x_ * scale_ + return x_ + + +def binary_hole(flow): + n, _, h, w = flow.shape + mask = softsplat( + tenIn=torch.ones((n, 1, h, w), device=flow.device), + tenFlow=flow, + tenMetric=None, + strMode="avg", + ) + ones = torch.ones_like(mask, device=mask.device) + zeros = torch.zeros_like(mask, device=mask.device) + out = torch.where(mask <= 0.5, ones, zeros) + return out + + +def warp_pyramid(features: list, metric, flow): + outputs = [] + masks = [] + for lv in range(3): + fea = features[lv] + if lv != 0: + h, w = fea.shape[-2:] + metric = resize(metric, (h, w), scale=False) + flow = resize(flow, (h, w), scale=True) + outputs.append(softsplat(fea, flow, metric.neg().clip(-20.0, 20.0), "soft")) + masks.append(binary_hole(flow)) + return outputs, masks + + +class FeaturePyramid(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netEncode = Encode() + self.netSoftmetric = Softmetric() + + def forward( + self, + tenOne, + tenTwo=None, + tenFlows: list[Tensor] = None, + time: float = 0.5, + ): + x1s = self.netEncode(tenOne) + if tenTwo is None: # just encode + return x1s + F12, F21 = tenFlows + x2s = self.netEncode(tenTwo) + m1t = self.netSoftmetric(x1s, x2s, F12) * 2 * time + F1t = time * F12 + m2t = self.netSoftmetric(x2s, x1s, F21) * 2 * (1 - time) + F2t = (1 - time) * F21 + Ft2 = -1 * softsplat(F2t, F2t, m2t.neg().clip(-20.0, 20.0), "soft") + x1s, bmasks = warp_pyramid(x1s, m1t, F1t) + return list(zip(x1s, x2s)), bmasks, Ft2 + + +class SoftBinary(torch.nn.Module): + def __init__(self, cin, dilate_size=7) -> None: + super().__init__() + channel = 64 + reduction = 8 + self.conv1 = torch.nn.Sequential( + *[ + torch.nn.Conv2d(1, channel, dilate_size, 1, padding="same", bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(channel, channel, 3, 1, padding="same", bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(channel, channel, 1, 1, padding="same", bias=False), + ] + ) + self.att = torch.nn.Conv2d(cin * 2, channel, 3, 1, padding="same") + 
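        # The pooling/FC pair below acts as a squeeze-and-excitation style gate:
        # the attention features are global-average-pooled, squeezed by
        # `reduction`, expanded back and passed through a sigmoid, giving
        # per-channel weights that rescale the dilated mask features
        # (broadcast over the spatial dimensions) before the final 1x1 conv.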
self.avg = torch.nn.AdaptiveAvgPool2d(1) + self.fc = torch.nn.Sequential( + torch.nn.Linear(channel, channel // reduction, bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Linear(channel // reduction, channel, bias=False), + torch.nn.Sigmoid(), + ) + self.conv2 = torch.nn.Conv2d(channel, 1, 1, 1, padding="same", bias=False) + + def forward(self, bmask, feaL, feaR): # N,1,H,W + m_fea = self.conv1(bmask) + x = self.att(torch.cat([feaL, feaR], dim=1)) + b, c, _, _ = x.size() + x = self.avg(x).view(b, c) + x = self.fc(x).view(b, c, 1, 1) + x = m_fea * x.expand_as(x) + x = self.conv2(x) + + x = torch.tanh(torch.abs(x)) + rand_bias = (torch.rand_like(x, device=x.device) - 0.5) / 100.0 + if self.training: + return x + rand_bias + return x + + +class DCNPack(torch.nn.Module): + def __init__(self, cin, groups, dksize): + super().__init__() + cout = groups * 3 * dksize**2 + self.conv_offset = torch.nn.Conv2d(cin, cout, 3, 1, 1) + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + self.dconv = DeformConv2d(cin, cin, dksize, padding=dksize // 2) + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger.info(f"Offset abs mean is {offset_absmean}, larger than 50.") + + return self.dconv(x, offset, mask) + + +class DeformableAlign(torch.nn.Module): + def __init__(self): + super().__init__() + channels = [35, 64, 96] + self.offset_conv1 = torch.nn.ModuleDict() + self.offset_conv2 = torch.nn.ModuleDict() + self.offset_conv3 = torch.nn.ModuleDict() + self.deform_conv = torch.nn.ModuleDict() + self.feat_conv = torch.nn.ModuleDict() + self.merge_conv1 = torch.nn.ModuleDict() + self.merge_conv2 = torch.nn.ModuleDict() + # Pyramids + for i in range(2, -1, -1): + level = f"l{i}" + c = channels[i] + # compute offsets + self.offset_conv1[level] = torch.nn.Conv2d(c * 2 + 3, c, 3, 1, 1) + if i == 2: + self.offset_conv2[level] = torch.nn.Conv2d(c, c, 3, 1, 1) + else: + self.offset_conv2[level] = torch.nn.Conv2d( + c + channels[i + 1], c, 3, 1, 1 + ) + self.offset_conv3[level] = torch.nn.Conv2d(c, c, 3, 1, 1) + # apply deform conv + if i == 0: + self.deform_conv[level] = DCNPack(c, 7, 3) + else: + self.deform_conv[level] = DCNPack(c, 8, 3) + self.merge_conv1[level] = torch.nn.Conv2d(c + c + 1, c, 3, 1, 1) + if i < 2: + self.feat_conv[level] = torch.nn.Conv2d(c + channels[i + 1], c, 3, 1, 1) + self.merge_conv2[level] = torch.nn.Conv2d( + c + channels[i + 1], c, 3, 1, 1 + ) + + self.upsample = torch.nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=False + ) + self.lrelu = torch.nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, features, bmasks, Ft2): + outs = [] + + for i in range(2, -1, -1): + level = f"l{i}" + feaL, feaR = features[i] + bmask = bmasks[i] + flow = resize(Ft2, bmask.shape[2:], scale=True) + offset = torch.cat([feaL, feaR, bmask, flow], dim=1) + offset = self.lrelu(self.offset_conv1[level](offset)) + if i == 2: + offset = self.lrelu(self.offset_conv2[level](offset)) + else: + offset = self.lrelu( + self.offset_conv2[level]( + torch.cat([offset, upsampled_offset], dim=1) + ) + ) + offset = self.lrelu(self.offset_conv3[level](offset)) + + warped_feaR = self.deform_conv[level](feaR, offset) + + if i < 2: + warped_feaR = self.feat_conv[level]( + torch.cat([warped_feaR, upsampled_feaR], dim=1) + ) + + merged_feat = self.merge_conv1[level]( + 
torch.cat([feaL, warped_feaR, bmask], dim=1) + ) + if i < 2: + merged_feat = self.merge_conv2[level]( + torch.cat([merged_feat, upsampled_merged_feat], dim=1) + ) + outs.append(merged_feat) + + if i > 0: # upsample offset and features + warped_feaR = self.lrelu(warped_feaR) + upsampled_offset = self.upsample(offset) * 2 + upsampled_feaR = self.upsample(warped_feaR) + upsampled_merged_feat = self.upsample(merged_feat) + + return outs + + +class AttentionMerge(torch.nn.Module): + def __init__(self, dilate_size=7): + super().__init__() + self.softbinary = torch.nn.ModuleDict() + channels = [35, 64, 96] + for i in range(2, -1, -1): + level = f"{i}" + c = channels[i] + self.softbinary[level] = SoftBinary(c, dilate_size) + + def forward(self, feaL, feaR, bmask): + outs = [] + soft_masks = [] + for i in range(2, -1, -1): + level = f"{i}" + sm = self.softbinary[level](bmask[i], feaL[i], feaR[i]) + soft_masks.append(sm) + x = feaL[i] * (1 - sm) + feaR[i] * sm + outs.append(x) + return outs, soft_masks + + +class Network(torch.torch.nn.Module): + def __init__(self, dilate_size=9): + super().__init__() + cond_c = [35, 64, 96] + self.featurePyramid = FeaturePyramid() + self.deformableAlign = DeformableAlign() + self.attentionMerge = AttentionMerge(dilate_size=dilate_size) + self.multiscaleFuse = MultiscaleFuse(cond_c) + self.condFLownet = CondFlowNet(cond_c, with_bn=False, train_1x1=True, K=16) + + def get_cond(self, inps: list, time: float = 0.5): + tenOne, tenTwo, fflow, bflow = inps + with accelerate.Accelerator().autocast(): + feas, bmasks, Ft2 = self.featurePyramid( + tenOne, tenTwo, [fflow, bflow], time + ) + feaR = self.deformableAlign(feas, bmasks, Ft2)[::-1] + feaL = [feas[i][0] for i in range(3)] + feas, smasks = self.attentionMerge(feaL, feaR, bmasks) + # feas = [F.interpolate(x, scale_factor=0.5, mode="bilinear") for x in feas] + feas = self.multiscaleFuse(feas[::-1]) # downscale by 2 + return feas, smasks + + def normalize(self, x, reverse=False): + # x in [0, 1] + if not reverse: + return x * 2 - 1 + else: + return (x + 1) / 2 + + def forward(self, gt=None, zs=None, inps=[], time=0.5, code="encode"): + if code == "encode": + return self.encode(gt, inps, time) + elif code == "decode": + return self.decode(zs, inps, time) + else: + return self.encode_decode(gt, inps, time, zs=zs) + + def encode(self, gt, inps: list, time: float = 0.5): + img0, img1 = [self.normalize(x) for x in inps[:2]] + gt = self.normalize(gt) + cond = [img0, img1] + inps[-2:] + pixels = thops.pixels(gt) + conds, smasks = self.get_cond(cond, time=time) + + # add random noise before normalizing flow net + loss = 0.0 + if self.training: + gt = gt + ((torch.rand_like(gt, device=gt.device) - 0.5) / 255.0) + loss += -log(255.0) * pixels + log_p, log_det, zs = self.condFLownet(gt, conds) + + loss /= float(log(2) * pixels) + log_p /= float(log(2) * pixels) + log_det /= float(log(2) * pixels) + nll = -(loss + log_det + log_p) + return nll, zs, smasks + + def decode(self, z_list: list, inps: list, time: float = 0.5): + img0, img1 = [self.normalize(x) for x in inps[:2]] + cond = [img0, img1] + inps[-2:] + + conds, smasks = self.get_cond(cond, time=time) + pred = self.condFLownet(z_list, conds, reverse=True) + pred = self.normalize(pred, reverse=True) + return pred, smasks + + def encode_decode(self, gt, inps: list, time: float = 0.5, zs=None): + img0, img1 = [self.normalize(x) for x in inps[:2]] + gt = self.normalize(gt) + cond = [img0, img1] + inps[-2:] + pixels = thops.pixels(gt) + conds, smasks = self.get_cond(cond, time=time) 
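        # The encode pass below mirrors `encode`: during training, uniform
        # dequantization noise in [-0.5/255, 0.5/255) is added to the target and
        # -log(255) * pixels is accumulated; dividing that term, log_det and
        # log_p by pixels * ln(2) expresses everything in bits per dimension,
        # so nll = -(loss + log_det + log_p) is a bits/dim negative log-likelihood.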
+ + # encode first + loss = 0.0 + if self.training: + gt = gt + ((torch.rand_like(gt, device=gt.device) - 0.5) / 255.0) + loss += -log(255.0) * pixels + log_p, log_det, zs_gt = self.condFLownet(gt, conds) + loss /= float(log(2) * pixels) + log_p /= float(log(2) * pixels) + log_det /= float(log(2) * pixels) + nll = -(loss + log_det + log_p) + + # decode next + if zs is None: + heat = torch.sqrt(torch.var(torch.cat([x.flatten() for x in zs_gt]))) + zs = self.get_z(heat, img0.shape[-2:], img0.shape[0], img0.device) + pred = self.condFLownet(zs, conds, reverse=True) + pred = self.normalize(pred, reverse=True) + return nll, pred, smasks + + def get_z(self, heat: float, img_size: tuple, batch: int, device: str): + def calc_z_shapes(img_size, n_levels): + h, w = img_size + z_shapes = [] + channel = 3 + + for _ in range(n_levels - 1): + h //= 2 + w //= 2 + channel *= 2 + z_shapes.append((channel, h, w)) + h //= 2 + w //= 2 + z_shapes.append((channel * 4, h, w)) + return z_shapes + + z_list = [] + z_shapes = calc_z_shapes(img_size, 3) + for z in z_shapes: + z_new = torch.randn(batch, *z, device=device) * heat + z_list.append(z_new) + return z_list diff --git a/src/pervfiarches/generators/PFlowVFI_V2.py b/src/pervfiarches/generators/PFlowVFI_V2.py new file mode 100644 index 00000000..90b56a8b --- /dev/null +++ b/src/pervfiarches/generators/PFlowVFI_V2.py @@ -0,0 +1,383 @@ +"""PerVFI: Soft-binary Blending for Photo-realistic Video Frame Interpolation + +""" +import accelerate +import torch +import torch.nn.functional as F +import logging as logger +from torchvision.ops import DeformConv2d + +from . import thops +from .msfusion import MultiscaleFuse +from .normalizing_flow import * +from .softsplatnet import Basic, Encode, Softmetric +from .softsplatnet.softsplat import softsplat + + +def resize(x, size: tuple, scale: bool): + H, W = x.shape[-2:] + h, w = size + scale_ = h / H + x_ = F.interpolate(x, size, mode="bilinear", align_corners=False) + if scale: + return x_ * scale_ + return x_ + + +def binary_hole(flow): + n, _, h, w = flow.shape + mask = softsplat( + tenIn=torch.ones((n, 1, h, w), device=flow.device), + tenFlow=flow, + tenMetric=None, + strMode="avg", + ) + ones = torch.ones_like(mask, device=mask.device) + zeros = torch.zeros_like(mask, device=mask.device) + out = torch.where(mask <= 0.5, ones, zeros) + return out + + +def warp_pyramid(features: list, metric, flow): + outputs = [] + masks = [] + for lv in range(3): + fea = features[lv] + if lv != 0: + h, w = fea.shape[-2:] + metric = resize(metric, (h, w), scale=False) + flow = resize(flow, (h, w), scale=True) + outputs.append(softsplat(fea, flow, metric.neg().clip(-20.0, 20.0), "soft")) + masks.append(binary_hole(flow)) + return outputs, masks + + +class FeaturePyramid(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netEncode = Encode() + self.netSoftmetric = Softmetric() + + def forward(self, tenOne, tenTwo, tenFlows: list, time=0.5): + F12, _ = tenFlows + x1s, x2s = self.netEncode(tenOne), self.netEncode(tenTwo) + m1t = self.netSoftmetric(x1s, x2s, F12) * 2 * time + F1t = time * F12 + x1s, bmasks = warp_pyramid(x1s, m1t, F1t) + return list(zip(x1s, x2s)), bmasks + + +class SoftBinary(torch.nn.Module): + def __init__(self, cin, dilate_size=5) -> None: + super().__init__() + channel = 64 + reduction = 8 + self.conv1 = torch.nn.Sequential( + *[ + torch.nn.Conv2d(1, channel, dilate_size, 1, padding="same", bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(channel, channel, 3, 1, padding="same", 
bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(channel, channel, 1, 1, padding="same", bias=False), + ] + ) + self.att = torch.nn.Conv2d(cin * 2, channel, 3, 1, padding="same") + self.avg = torch.nn.AdaptiveAvgPool2d(1) + self.fc = torch.nn.Sequential( + torch.nn.Linear(channel, channel // reduction, bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Linear(channel // reduction, channel, bias=False), + torch.nn.Sigmoid(), + ) + self.conv2 = torch.nn.Conv2d(channel, 1, 1, 1, padding="same", bias=False) + + def forward(self, bmask, feaL, feaR): # N,1,H,W + m_fea = self.conv1(bmask) + x = self.att(torch.cat([feaL, feaR], dim=1)) + b, c, _, _ = x.size() + x = self.avg(x).view(b, c) + x = self.fc(x).view(b, c, 1, 1) + x = m_fea * x.expand_as(x) + x = self.conv2(x) + + x = torch.tanh(torch.abs(x)) + rand_bias = (torch.rand_like(x, device=x.device) - 0.5) / 100.0 + if self.training: + return x + rand_bias + else: + return x + + +class AttentionMerge(torch.nn.Module): + def __init__(self, dilate_size=7): + super().__init__() + self.softbinary = torch.nn.ModuleDict() + channels = [35, 64, 96] + for i in range(2, -1, -1): + level = f"{i}" + c = channels[i] + self.softbinary[level] = SoftBinary(c, dilate_size) + + def forward(self, feaL, feaR, bmask): + outs = [] + soft_masks = [] + for i in range(2, -1, -1): + level = f"{i}" + sm = self.softbinary[level](bmask[i], feaL[i], feaR[i]) + soft_masks.append(sm) + x = feaL[i] * (1 - sm) + feaR[i] * sm + outs.append(x) + return outs, soft_masks + + +class DCNPack(torch.nn.Module): + def __init__(self, cin, groups, dksize): + super().__init__() + cout = groups * 3 * dksize**2 + self.conv_offset = torch.nn.Conv2d(cin, cout, 3, 1, 1) + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + self.dconv = DeformConv2d(cin, cin, dksize, padding=dksize // 2) + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger.info(f"Offset abs mean is {offset_absmean}, larger than 50.") + + return self.dconv(x, offset, mask) + + +class DeformableAlign(torch.nn.Module): + def __init__(self): + super().__init__() + channels = [35, 64, 96] + self.offset_conv1 = torch.nn.ModuleDict() + self.offset_conv2 = torch.nn.ModuleDict() + self.offset_conv3 = torch.nn.ModuleDict() + self.deform_conv = torch.nn.ModuleDict() + self.feat_conv = torch.nn.ModuleDict() + self.merge_conv1 = torch.nn.ModuleDict() + self.merge_conv2 = torch.nn.ModuleDict() + # Pyramids + for i in range(2, -1, -1): + level = f"l{i}" + c = channels[i] + # compute offsets + self.offset_conv1[level] = torch.nn.Conv2d(c * 2 + 1, c, 3, 1, 1) + if i == 2: + self.offset_conv2[level] = torch.nn.Conv2d(c, c, 3, 1, 1) + else: + self.offset_conv2[level] = torch.nn.Conv2d( + c + channels[i + 1], c, 3, 1, 1 + ) + self.offset_conv3[level] = torch.nn.Conv2d(c, c, 3, 1, 1) + # apply deform conv + if i == 0: + self.deform_conv[level] = DCNPack(c, 7, 3) + else: + self.deform_conv[level] = DCNPack(c, 8, 3) + self.merge_conv1[level] = torch.nn.Conv2d(c + c + 1, c, 3, 1, 1) + if i < 2: + self.feat_conv[level] = torch.nn.Conv2d(c + channels[i + 1], c, 3, 1, 1) + self.merge_conv2[level] = torch.nn.Conv2d( + c + channels[i + 1], c, 3, 1, 1 + ) + + self.upsample = torch.nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=False + ) + self.lrelu = torch.nn.LeakyReLU(negative_slope=0.1, 
inplace=True) + + def forward(self, features, bmasks): + outs = [] + + for i in range(2, -1, -1): + level = f"l{i}" + feaL, feaR = features[i] + bmask = bmasks[i] + # flow = resize(Ft2, bmask.shape[2:], scale=True) + offset = torch.cat([feaL, feaR, bmask], dim=1) + offset = self.lrelu(self.offset_conv1[level](offset)) + if i == 2: + offset = self.lrelu(self.offset_conv2[level](offset)) + else: + offset = self.lrelu( + self.offset_conv2[level]( + torch.cat([offset, upsampled_offset], dim=1) + ) + ) + offset = self.lrelu(self.offset_conv3[level](offset)) + + warped_feaR = self.deform_conv[level](feaR, offset) + + if i < 2: + warped_feaR = self.feat_conv[level]( + torch.cat([warped_feaR, upsampled_feaR], dim=1) + ) + + merged_feat = self.merge_conv1[level]( + torch.cat([feaL, warped_feaR, bmask], dim=1) + ) + if i < 2: + merged_feat = self.merge_conv2[level]( + torch.cat([merged_feat, upsampled_merged_feat], dim=1) + ) + outs.append(merged_feat) + + if i > 0: # upsample offset and features + warped_feaR = self.lrelu(warped_feaR) + upsampled_offset = self.upsample(offset) * 2 + upsampled_feaR = self.upsample(warped_feaR) + upsampled_merged_feat = self.upsample(merged_feat) + + return outs + + +class Decoder(torch.nn.Module): + def __init__(self, cin): + super().__init__() + self.conv3 = Basic("conv-relu-conv", [cin[2], 128, cin[1]], True) + self.conv2 = Basic("conv-relu-conv", [cin[1] * 2, 256, cin[0]], True) + self.conv1 = Basic("conv-relu-conv", [cin[0] * 2, 96, 128], True) + self.tail = torch.nn.Conv2d(128 // 4, 3, 3, 1, padding="same") + self.up = torch.nn.UpsamplingBilinear2d(scale_factor=2) + # logger.info( + # f"Parameter of decoder: {sum(p.numel() for p in self.parameters())}" + # ) + + def forward(self, xs): + lv1, lv2, lv3 = xs + lv3 = self.conv3(lv3) + lv3 = self.up(lv3) + lv2 = self.conv2(torch.cat([lv3, lv2], dim=1)) + lv2 = self.up(lv2) + lv1 = self.conv1(torch.cat([lv2, lv1], dim=1)) + lv1 = unsqueeze2d(lv1, factor=2) + return self.tail(lv1) + + +class Network_(torch.torch.nn.Module): + def __init__(self, dilate_size=5): + super().__init__() + self.cond_c = [35, 64, 96] + self.featurePyramid = FeaturePyramid() + self.deformableAlign = DeformableAlign() + self.attentionMerge = AttentionMerge(dilate_size=dilate_size) + self.multiscaleFuse = MultiscaleFuse(self.cond_c) + + # self.generator = Decoder(self.cond_c) + # self.condFLownet = CondFlowNet(cond_c, with_bn=False, train_1x1=True) + + def get_cond(self, inps: list, time: float = 0.5): + tenOne, tenTwo, fflow, _ = inps + with accelerate.Accelerator().autocast(): + feas, bmasks = self.featurePyramid(tenOne, tenTwo, [fflow, None], time) + feaR = self.deformableAlign(feas, bmasks)[::-1] + feaL = [feas[i][0] for i in range(3)] + feas, smasks = self.attentionMerge(feaL, feaR, bmasks) + # feas = [F.interpolate(x, scale_factor=0.5, mode="bilinear") for x in feas] + feas = self.multiscaleFuse(feas[::-1]) # downscale by 2 + return feas, smasks + + def normalize(self, x, reverse=False): + # x in [0, 1] + if not reverse: + return x * 2 - 1 + else: + return (x + 1) / 2 + + +class Network_base(Network_): + def __init__(self, dilate_size=5): + super().__init__(dilate_size=dilate_size) + self.generator = Decoder(self.cond_c) + + def normalize(self, x, reverse=False): + # x in [0, 1] + if not reverse: + return x * 2 - 1 + else: + return (x + 1) / 2 + + def forward(self, inps=[], time=0.5, **kwargs): + img0, img1 = [self.normalize(x) for x in inps[:2]] + cond = [img0, img1] + inps[-2:] + + conds, smasks = self.get_cond(cond, time=time) + with 
accelerate.Accelerator().autocast(): + pred = self.generator(conds) + pred = self.normalize(pred, reverse=True) + return pred, smasks + + +class Network_flow(Network_): + def __init__(self, dilate_size=5): + super().__init__(dilate_size=dilate_size) + self.condFLownet = CondFlowNet(self.cond_c, with_bn=False, train_1x1=True) + + def forward(self, gt=None, zs=None, inps=[], time=0.5, code="encode"): + if code == "encode": + return self.encode(gt, inps, time) + elif code == "decode": + return self.decode(zs, inps, time) + else: + return self.encode_decode(gt, inps, time, zs=zs) + + def encode(self, gt, inps: list, time: float = 0.5): + img0, img1 = [self.normalize(x) for x in inps[:2]] + gt = self.normalize(gt) + cond = [img0, img1] + inps[-2:] + pixels = thops.pixels(gt) + conds, smasks = self.get_cond(cond, time=time) + + # add random noise before normalizing flow net + loss = 0.0 + if self.training: + gt = gt + ((torch.rand_like(gt, device=gt.device) - 0.5) / 255.0) + loss += -log(255.0) * pixels + log_p, log_det, zs = self.condFLownet(gt, conds) + + loss /= float(log(2) * pixels) + log_p /= float(log(2) * pixels) + log_det /= float(log(2) * pixels) + nll = -(loss + log_det + log_p) + return nll, zs, smasks + + def decode(self, z_list: list, inps: list, time: float = 0.5): + img0, img1 = [self.normalize(x) for x in inps[:2]] + cond = [img0, img1] + inps[-2:] + + conds, smasks = self.get_cond(cond, time=time) + pred = self.condFLownet(z_list, conds, reverse=True) + pred = self.normalize(pred, reverse=True) + return pred, smasks + + def encode_decode(self, gt, inps: list, time: float = 0.5, zs=None): + img0, img1 = [self.normalize(x) for x in inps[:2]] + gt = self.normalize(gt) + cond = [img0, img1] + inps[-2:] + pixels = thops.pixels(gt) + conds, smasks = self.get_cond(cond, time=time) + + # encode first + loss = 0.0 + if self.training: + gt = gt + ((torch.rand_like(gt, device=gt.device) - 0.5) / 255.0) + loss += -log(255.0) * pixels + log_p, log_det, zs_gt = self.condFLownet(gt, conds) + loss /= float(log(2) * pixels) + log_p /= float(log(2) * pixels) + log_det /= float(log(2) * pixels) + nll = -(loss + log_det + log_p) + + # decode next + zs = zs_gt if zs is None else zs + pred = self.condFLownet(zs, conds, reverse=True) + pred = self.normalize(pred, reverse=True) + return nll, pred, smasks diff --git a/src/pervfiarches/generators/PFlowVFI_Vb.py b/src/pervfiarches/generators/PFlowVFI_Vb.py new file mode 100644 index 00000000..65529cb9 --- /dev/null +++ b/src/pervfiarches/generators/PFlowVFI_Vb.py @@ -0,0 +1,316 @@ +"""PerVFI: Soft-binary Blending for Photo-realistic Video Frame Interpolation + +""" +import accelerate +import torch +import torch.nn.functional as F +import logging as logger +from torch import Tensor +from torchvision.ops import DeformConv2d + +from .msfusion import MultiscaleFuse +from .normalizing_flow import * +from .softsplatnet import Basic, Encode, Softmetric +from .softsplatnet.softsplat import softsplat + + +def resize(x, size: tuple, scale: bool): + H, W = x.shape[-2:] + h, w = size + scale_ = h / H + x_ = F.interpolate(x, size, mode="bilinear", align_corners=False) + if scale: + return x_ * scale_ + return x_ + + +def binary_hole(flow): + n, _, h, w = flow.shape + mask = softsplat( + tenIn=torch.ones((n, 1, h, w), device=flow.device), + tenFlow=flow, + tenMetric=None, + strMode="avg", + ) + ones = torch.ones_like(mask, device=mask.device) + zeros = torch.zeros_like(mask, device=mask.device) + out = torch.where(mask <= 0.5, ones, zeros) + return out + + +def 
warp_pyramid(features: list, metric, flow): + outputs = [] + masks = [] + for lv in range(3): + fea = features[lv] + if lv != 0: + h, w = fea.shape[-2:] + metric = resize(metric, (h, w), scale=False) + flow = resize(flow, (h, w), scale=True) + outputs.append(softsplat(fea, flow, metric.neg().clip(-20.0, 20.0), "soft")) + masks.append(binary_hole(flow)) + return outputs, masks + + +class FeaturePyramid(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netEncode = Encode() + self.netSoftmetric = Softmetric() + + def forward( + self, + tenOne, + tenTwo=None, + tenFlows: list[Tensor] = None, + time: float = 0.5, + ): + x1s = self.netEncode(tenOne) + if tenTwo is None: # just encode + return x1s + F12, F21 = tenFlows + x2s = self.netEncode(tenTwo) + m1t = self.netSoftmetric(x1s, x2s, F12) * 2 * time + F1t = time * F12 + m2t = self.netSoftmetric(x2s, x1s, F21) * 2 * (1 - time) + F2t = (1 - time) * F21 + Ft2 = -1 * softsplat(F2t, F2t, m2t.neg().clip(-20.0, 20.0), "soft") + x1s, bmasks = warp_pyramid(x1s, m1t, F1t) + return list(zip(x1s, x2s)), bmasks, Ft2 + + +class SoftBinary(torch.nn.Module): + def __init__(self, cin, dilate_size=7) -> None: + super().__init__() + channel = 64 + reduction = 8 + self.conv1 = torch.nn.Sequential( + *[ + torch.nn.Conv2d(1, channel, dilate_size, 1, padding="same", bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(channel, channel, 3, 1, padding="same", bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(channel, channel, 1, 1, padding="same", bias=False), + ] + ) + self.att = torch.nn.Conv2d(cin * 2, channel, 3, 1, padding="same") + self.avg = torch.nn.AdaptiveAvgPool2d(1) + self.fc = torch.nn.Sequential( + torch.nn.Linear(channel, channel // reduction, bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Linear(channel // reduction, channel, bias=False), + torch.nn.Sigmoid(), + ) + self.conv2 = torch.nn.Conv2d(channel, 1, 1, 1, padding="same", bias=False) + + def forward(self, bmask, feaL, feaR): # N,1,H,W + m_fea = self.conv1(bmask) + x = self.att(torch.cat([feaL, feaR], dim=1)) + b, c, _, _ = x.size() + x = self.avg(x).view(b, c) + x = self.fc(x).view(b, c, 1, 1) + x = m_fea * x.expand_as(x) + x = self.conv2(x) + + x = torch.tanh(torch.abs(x)) + rand_bias = (torch.rand_like(x, device=x.device) - 0.5) / 100.0 + if self.training: + return x + rand_bias + else: + return x + + +class DCNPack(torch.nn.Module): + def __init__(self, cin, groups, dksize): + super().__init__() + cout = groups * 3 * dksize**2 + self.conv_offset = torch.nn.Conv2d(cin, cout, 3, 1, 1) + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + self.dconv = DeformConv2d(cin, cin, dksize, padding=dksize // 2) + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger.info(f"Offset abs mean is {offset_absmean}, larger than 50.") + + return self.dconv(x, offset, mask) + + +class DeformableAlign(torch.nn.Module): + def __init__(self): + super().__init__() + channels = [35, 64, 96] + self.offset_conv1 = torch.nn.ModuleDict() + self.offset_conv2 = torch.nn.ModuleDict() + self.offset_conv3 = torch.nn.ModuleDict() + self.deform_conv = torch.nn.ModuleDict() + self.feat_conv = torch.nn.ModuleDict() + self.merge_conv1 = torch.nn.ModuleDict() + self.merge_conv2 = torch.nn.ModuleDict() + # Pyramids + for i in range(2, -1, -1): + level = f"l{i}" + 
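            # offset_conv1 consumes c * 2 + 3 channels at each level: the
            # frame-1 features splatted to time t and the frame-2 features
            # (c channels each), the 1-channel hole mask, and the 2-channel
            # backward flow Ft2 resized to the current pyramid scale.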
c = channels[i] + # compute offsets + self.offset_conv1[level] = torch.nn.Conv2d(c * 2 + 3, c, 3, 1, 1) + if i == 2: + self.offset_conv2[level] = torch.nn.Conv2d(c, c, 3, 1, 1) + else: + self.offset_conv2[level] = torch.nn.Conv2d( + c + channels[i + 1], c, 3, 1, 1 + ) + self.offset_conv3[level] = torch.nn.Conv2d(c, c, 3, 1, 1) + # apply deform conv + if i == 0: + self.deform_conv[level] = DCNPack(c, 7, 3) + else: + self.deform_conv[level] = DCNPack(c, 8, 3) + self.merge_conv1[level] = torch.nn.Conv2d(c + c + 1, c, 3, 1, 1) + if i < 2: + self.feat_conv[level] = torch.nn.Conv2d(c + channels[i + 1], c, 3, 1, 1) + self.merge_conv2[level] = torch.nn.Conv2d( + c + channels[i + 1], c, 3, 1, 1 + ) + + self.upsample = torch.nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=False + ) + self.lrelu = torch.nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, features, bmasks, Ft2): + outs = [] + + for i in range(2, -1, -1): + level = f"l{i}" + feaL, feaR = features[i] + bmask = bmasks[i] + flow = resize(Ft2, bmask.shape[2:], scale=True) + offset = torch.cat([feaL, feaR, bmask, flow], dim=1) + offset = self.lrelu(self.offset_conv1[level](offset)) + if i == 2: + offset = self.lrelu(self.offset_conv2[level](offset)) + else: + offset = self.lrelu( + self.offset_conv2[level]( + torch.cat([offset, upsampled_offset], dim=1) + ) + ) + offset = self.lrelu(self.offset_conv3[level](offset)) + + warped_feaR = self.deform_conv[level](feaR, offset) + + if i < 2: + warped_feaR = self.feat_conv[level]( + torch.cat([warped_feaR, upsampled_feaR], dim=1) + ) + + merged_feat = self.merge_conv1[level]( + torch.cat([feaL, warped_feaR, bmask], dim=1) + ) + if i < 2: + merged_feat = self.merge_conv2[level]( + torch.cat([merged_feat, upsampled_merged_feat], dim=1) + ) + outs.append(merged_feat) + + if i > 0: # upsample offset and features + warped_feaR = self.lrelu(warped_feaR) + upsampled_offset = self.upsample(offset) * 2 + upsampled_feaR = self.upsample(warped_feaR) + upsampled_merged_feat = self.upsample(merged_feat) + + return outs + + +class AttentionMerge(torch.nn.Module): + def __init__(self, dilate_size=7): + super().__init__() + self.softbinary = torch.nn.ModuleDict() + channels = [35, 64, 96] + for i in range(2, -1, -1): + level = f"{i}" + c = channels[i] + self.softbinary[level] = SoftBinary(c, dilate_size) + + def forward(self, feaL, feaR, bmask): + outs = [] + soft_masks = [] + for i in range(2, -1, -1): + level = f"{i}" + sm = self.softbinary[level](bmask[i], feaL[i], feaR[i]) + soft_masks.append(sm) + x = feaL[i] * (1 - sm) + feaR[i] * sm + outs.append(x) + return outs, soft_masks + + +class Decoder(torch.nn.Module): + def __init__(self, cin): + super().__init__() + self.conv3 = Basic("conv-relu-conv", [cin[2], 128, cin[1]], True) + self.conv2 = Basic("conv-relu-conv", [cin[1] * 2, 256, cin[0]], True) + self.conv1 = Basic("conv-relu-conv", [cin[0] * 2, 96, 128], True) + self.tail = torch.nn.Conv2d(128 // 4, 3, 3, 1, padding="same") + self.up = torch.nn.UpsamplingBilinear2d(scale_factor=2) + logger.info( + f"Parameter of decoder: {sum(p.numel() for p in self.parameters())}" + ) + + def forward(self, xs): + lv1, lv2, lv3 = xs + lv3 = self.conv3(lv3) + lv3 = self.up(lv3) + lv2 = self.conv2(torch.cat([lv3, lv2], dim=1)) + lv2 = self.up(lv2) + lv1 = self.conv1(torch.cat([lv2, lv1], dim=1)) + lv1 = unsqueeze2d(lv1, factor=2) + return self.tail(lv1) + + +class Network(torch.torch.nn.Module): + def __init__(self, dilate_size=9): + super().__init__() + cond_c = [35, 64, 96] + 
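        # cond_c holds the per-level feature widths of the three-scale pyramid.
        # The modules below form the Vb pipeline: build and softmax-splat a
        # feature pyramid, align the second frame's features with deformable
        # convolutions, blend both sides with soft-binary masks, fuse the
        # scales, and decode the fused features directly into the output frame.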
self.featurePyramid = FeaturePyramid() + self.deformableAlign = DeformableAlign() + self.attentionMerge = AttentionMerge(dilate_size=dilate_size) + self.multiscaleFuse = MultiscaleFuse(cond_c) + self.generator = Decoder(cond_c) + # self.condFLownet = CondFlowNet(cond_c, with_bn=False, train_1x1=True) + + def get_cond(self, inps: list, time: float = 0.5): + tenOne, tenTwo, fflow, bflow = inps + with accelerate.Accelerator().autocast(): + feas, bmasks, Ft2 = self.featurePyramid( + tenOne, tenTwo, [fflow, bflow], time + ) + feaR = self.deformableAlign(feas, bmasks, Ft2)[::-1] + feaL = [feas[i][0] for i in range(3)] + feas, smasks = self.attentionMerge(feaL, feaR, bmasks) + # feas = [F.interpolate(x, scale_factor=0.5, mode="bilinear") for x in feas] + feas = self.multiscaleFuse(feas[::-1]) # downscale by 2 + return feas, smasks + + def normalize(self, x, reverse=False): + # x in [0, 1] + if not reverse: + return x * 2 - 1 + else: + return (x + 1) / 2 + + def forward(self, inps=[], time=0.5, **kwargs): + img0, img1 = [self.normalize(x) for x in inps[:2]] + cond = [img0, img1] + inps[-2:] + + conds, smasks = self.get_cond(cond, time=time) + with accelerate.Accelerator().autocast(): + pred = self.generator(conds) + pred = self.normalize(pred, reverse=True) + return pred, smasks diff --git a/src/pervfiarches/generators/PFlowVFI_ablation.py b/src/pervfiarches/generators/PFlowVFI_ablation.py new file mode 100644 index 00000000..baf26aea --- /dev/null +++ b/src/pervfiarches/generators/PFlowVFI_ablation.py @@ -0,0 +1,331 @@ +"""PerVFI: Soft-binary Blending for Photo-realistic Video Frame Interpolation + +""" +import accelerate +import torch +import torch.nn.functional as F +import logging as logger +from torch import Tensor +from torchvision.ops import DeformConv2d + +from .msfusion import MultiscaleFuse +from .normalizing_flow import * +from .softsplatnet import Basic, Encode, Softmetric +from .softsplatnet.softsplat import softsplat + + +def resize(x, size: tuple, scale: bool): + H, W = x.shape[-2:] + h, w = size + scale_ = h / H + x_ = F.interpolate(x, size, mode="bilinear", align_corners=False) + if scale: + return x_ * scale_ + return x_ + + +def binary_hole(flow): + n, _, h, w = flow.shape + mask = softsplat( + tenIn=torch.ones((n, 1, h, w), device=flow.device), + tenFlow=flow, + tenMetric=None, + strMode="avg", + ) + ones = torch.ones_like(mask, device=mask.device) + zeros = torch.zeros_like(mask, device=mask.device) + out = torch.where(mask <= 0.5, ones, zeros) + return out + + +def warp_pyramid(features: list, metric, flow): + outputs = [] + masks = [] + for lv in range(3): + fea = features[lv] + if lv != 0: + h, w = fea.shape[-2:] + metric = resize(metric, (h, w), scale=False) + flow = resize(flow, (h, w), scale=True) + outputs.append(softsplat(fea, flow, metric.neg().clip(-20.0, 20.0), "soft")) + masks.append(binary_hole(flow)) + return outputs, masks + + +class FeaturePyramid(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netEncode = Encode() + self.netSoftmetric = Softmetric() + + def forward( + self, + tenOne, + tenTwo=None, + tenFlows: list[Tensor] = None, + time: float = 0.5, + ): + x1s = self.netEncode(tenOne) + if tenTwo is None: # just encode + return x1s + F12, F21 = tenFlows + x2s = self.netEncode(tenTwo) + m1t = self.netSoftmetric(x1s, x2s, F12) * 2 * time + F1t = time * F12 + m2t = self.netSoftmetric(x2s, x1s, F21) * 2 * (1 - time) + F2t = (1 - time) * F21 + Ft2 = -1 * softsplat(F2t, F2t, m2t.neg().clip(-20.0, 20.0), "soft") + x1s, bmasks = 
warp_pyramid(x1s, m1t, F1t) + return list(zip(x1s, x2s)), bmasks, Ft2 + + +class SoftBinary(torch.nn.Module): + def __init__( + self, cin, dilate_size=7, mask_type="quasi-binary", noise=True + ) -> None: + super().__init__() + channel = 64 + reduction = 8 + assert mask_type in ["quasi-binary", "binary", "adaptive"] + self.dilation = not (mask_type == "binary") # no dilation if binary + self.adaptive = mask_type == "adaptive" # use adaptive mask w/o sparsity + self.use_noise = noise + bias = self.adaptive + + if self.dilation: + self.conv1 = torch.nn.Sequential( + *[ + torch.nn.Conv2d( + 1, channel, dilate_size, 1, padding="same", bias=bias + ), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(channel, channel, 3, 1, padding="same", bias=bias), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(channel, channel, 1, 1, padding="same", bias=bias), + ] + ) + self.att = torch.nn.Conv2d(cin * 2, channel, 3, 1, padding="same") + self.avg = torch.nn.AdaptiveAvgPool2d(1) + self.fc = torch.nn.Sequential( + torch.nn.Linear(channel, channel // reduction, bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Linear(channel // reduction, channel, bias=False), + torch.nn.Sigmoid(), + ) + self.conv2 = torch.nn.Conv2d(channel, 1, 1, 1, padding="same", bias=bias) + + def forward(self, bmask, feaL, feaR): # N,1,H,W + if self.dilation: + m_fea = self.conv1(bmask) + x = self.att(torch.cat([feaL, feaR], dim=1)) + b, c, _, _ = x.size() + x = self.avg(x).view(b, c) + x = self.fc(x).view(b, c, 1, 1) + x = m_fea * x.expand_as(x) + x = self.conv2(x) + if self.adaptive: + x = torch.sigmoid(x) + else: + x = torch.tanh(torch.abs(x)) + else: + x = bmask + if self.use_noise and self.training: + rand_bias = (torch.rand_like(x, device=x.device) - 0.5) / 100.0 + return x + rand_bias + return x + + +class DCNPack(torch.nn.Module): + def __init__(self, cin, groups, dksize): + super().__init__() + cout = groups * 3 * dksize**2 + self.conv_offset = torch.nn.Conv2d(cin, cout, 3, 1, 1) + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + self.dconv = DeformConv2d(cin, cin, dksize, padding=dksize // 2) + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger.info(f"Offset abs mean is {offset_absmean}, larger than 50.") + + return self.dconv(x, offset, mask) + + +class DeformableAlign(torch.nn.Module): + def __init__(self): + super().__init__() + channels = [35, 64, 96] + self.offset_conv1 = torch.nn.ModuleDict() + self.offset_conv2 = torch.nn.ModuleDict() + self.offset_conv3 = torch.nn.ModuleDict() + self.deform_conv = torch.nn.ModuleDict() + self.feat_conv = torch.nn.ModuleDict() + self.merge_conv1 = torch.nn.ModuleDict() + self.merge_conv2 = torch.nn.ModuleDict() + # Pyramids + for i in range(2, -1, -1): + level = f"l{i}" + c = channels[i] + # compute offsets + self.offset_conv1[level] = torch.nn.Conv2d(c * 2 + 3, c, 3, 1, 1) + if i == 2: + self.offset_conv2[level] = torch.nn.Conv2d(c, c, 3, 1, 1) + else: + self.offset_conv2[level] = torch.nn.Conv2d( + c + channels[i + 1], c, 3, 1, 1 + ) + self.offset_conv3[level] = torch.nn.Conv2d(c, c, 3, 1, 1) + # apply deform conv + if i == 0: + self.deform_conv[level] = DCNPack(c, 7, 3) + else: + self.deform_conv[level] = DCNPack(c, 8, 3) + self.merge_conv1[level] = torch.nn.Conv2d(c + c + 1, c, 3, 1, 1) + if i < 2: + self.feat_conv[level] = torch.nn.Conv2d(c + 
channels[i + 1], c, 3, 1, 1) + self.merge_conv2[level] = torch.nn.Conv2d( + c + channels[i + 1], c, 3, 1, 1 + ) + + self.upsample = torch.nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=False + ) + self.lrelu = torch.nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, features, bmasks, Ft2): + outs = [] + + for i in range(2, -1, -1): + level = f"l{i}" + feaL, feaR = features[i] + bmask = bmasks[i] + flow = resize(Ft2, bmask.shape[2:], scale=True) + offset = torch.cat([feaL, feaR, bmask, flow], dim=1) + offset = self.lrelu(self.offset_conv1[level](offset)) + if i == 2: + offset = self.lrelu(self.offset_conv2[level](offset)) + else: + offset = self.lrelu( + self.offset_conv2[level]( + torch.cat([offset, upsampled_offset], dim=1) + ) + ) + offset = self.lrelu(self.offset_conv3[level](offset)) + + warped_feaR = self.deform_conv[level](feaR, offset) + + if i < 2: + warped_feaR = self.feat_conv[level]( + torch.cat([warped_feaR, upsampled_feaR], dim=1) + ) + + merged_feat = self.merge_conv1[level]( + torch.cat([feaL, warped_feaR, bmask], dim=1) + ) + if i < 2: + merged_feat = self.merge_conv2[level]( + torch.cat([merged_feat, upsampled_merged_feat], dim=1) + ) + outs.append(merged_feat) + + if i > 0: # upsample offset and features + warped_feaR = self.lrelu(warped_feaR) + upsampled_offset = self.upsample(offset) * 2 + upsampled_feaR = self.upsample(warped_feaR) + upsampled_merged_feat = self.upsample(merged_feat) + + return outs + + +class AttentionMerge(torch.nn.Module): + def __init__(self, dilate_size=7, **kwargs): + super().__init__() + self.softbinary = torch.nn.ModuleDict() + channels = [35, 64, 96] + for i in range(2, -1, -1): + level = f"{i}" + c = channels[i] + self.softbinary[level] = SoftBinary(c, dilate_size, **kwargs) + + def forward(self, feaL, feaR, bmask): + outs = [] + soft_masks = [] + for i in range(2, -1, -1): + level = f"{i}" + sm = self.softbinary[level](bmask[i], feaL[i], feaR[i]) + soft_masks.append(sm) + x = feaL[i] * (1 - sm) + feaR[i] * sm + outs.append(x) + return outs, soft_masks + + +class Decoder(torch.nn.Module): + def __init__(self, cin): + super().__init__() + self.conv3 = Basic("conv-relu-conv", [cin[2], 128, cin[1]], True) + self.conv2 = Basic("conv-relu-conv", [cin[1] * 2, 256, cin[0]], True) + self.conv1 = Basic("conv-relu-conv", [cin[0] * 2, 96, 128], True) + self.tail = torch.nn.Conv2d(128 // 4, 3, 3, 1, padding="same") + self.up = torch.nn.UpsamplingBilinear2d(scale_factor=2) + logger.info( + f"Parameter of decoder: {sum(p.numel() for p in self.parameters())}" + ) + + def forward(self, xs): + lv1, lv2, lv3 = xs + lv3 = self.conv3(lv3) + lv3 = self.up(lv3) + lv2 = self.conv2(torch.cat([lv3, lv2], dim=1)) + lv2 = self.up(lv2) + lv1 = self.conv1(torch.cat([lv2, lv1], dim=1)) + lv1 = unsqueeze2d(lv1, factor=2) + return self.tail(lv1) + + +class Network(torch.torch.nn.Module): + def __init__(self, dilate_size=9, **kwargs): + super().__init__() + cond_c = [35, 64, 96] + self.featurePyramid = FeaturePyramid() + self.deformableAlign = DeformableAlign() + self.attentionMerge = AttentionMerge(dilate_size=dilate_size, **kwargs) + self.multiscaleFuse = MultiscaleFuse(cond_c) + self.generator = Decoder(cond_c) + # self.condFLownet = CondFlowNet(cond_c, with_bn=False, train_1x1=True) + + def get_cond(self, inps: list, time: float = 0.5): + tenOne, tenTwo, fflow, bflow = inps + with accelerate.Accelerator().autocast(): + feas, bmasks, Ft2 = self.featurePyramid( + tenOne, tenTwo, [fflow, bflow], time + ) + feaR = self.deformableAlign(feas, 
bmasks, Ft2)[::-1] + feaL = [feas[i][0] for i in range(3)] + feas, smasks = self.attentionMerge(feaL, feaR, bmasks) + # feas = [F.interpolate(x, scale_factor=0.5, mode="bilinear") for x in feas] + feas = self.multiscaleFuse(feas[::-1]) # downscale by 2 + return feas, smasks + + def normalize(self, x, reverse=False): + # x in [0, 1] + if not reverse: + return x * 2 - 1 + else: + return (x + 1) / 2 + + def forward(self, inps=[], time=0.5, **kwargs): + img0, img1 = [self.normalize(x) for x in inps[:2]] + cond = [img0, img1] + inps[-2:] + + conds, smasks = self.get_cond(cond, time=time) + with accelerate.Accelerator().autocast(): + pred = self.generator(conds) + pred = self.normalize(pred, reverse=True) + return pred, smasks diff --git a/src/pervfiarches/generators/PFlowVFI_adaptive.py b/src/pervfiarches/generators/PFlowVFI_adaptive.py new file mode 100644 index 00000000..50622f12 --- /dev/null +++ b/src/pervfiarches/generators/PFlowVFI_adaptive.py @@ -0,0 +1,198 @@ +"""PerVFI: Fully-Adaptive mask + +""" +import accelerate +import torch +import torch.nn.functional as F +import logging as logger +from torch import Tensor + +from .msfusion import MultiscaleFuse +from .normalizing_flow import * +from .softsplatnet import Basic, Encode, Softmetric +from .softsplatnet.softsplat import softsplat + + +def resize(x, size: tuple, scale: bool): + H, W = x.shape[-2:] + h, w = size + scale_ = h / H + x_ = F.interpolate(x, size, mode="bilinear", align_corners=False) + if scale: + return x_ * scale_ + return x_ + + +def binary_hole(flow): + n, _, h, w = flow.shape + mask = softsplat( + tenIn=torch.ones((n, 1, h, w), device=flow.device), + tenFlow=flow, + tenMetric=None, + strMode="avg", + ) + ones = torch.ones_like(mask, device=mask.device) + zeros = torch.zeros_like(mask, device=mask.device) + out = torch.where(mask <= 0.5, ones, zeros) + return out + + +def warp_pyramid(features: list, metric, flow): + outputs = [] + masks = [] + for lv in range(3): + fea = features[lv] + if lv != 0: + h, w = fea.shape[-2:] + metric = resize(metric, (h, w), scale=False) + flow = resize(flow, (h, w), scale=True) + outputs.append(softsplat(fea, flow, metric.neg().clip(-20.0, 20.0), "soft")) + masks.append(binary_hole(flow)) + return outputs, masks + + +class FeaturePyramid(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netEncode = Encode() + self.netSoftmetric = Softmetric() + + def forward( + self, + tenOne, + tenTwo=None, + tenFlows: list[Tensor] = None, + time: float = 0.5, + ): + x1s = self.netEncode(tenOne) + x2s = self.netEncode(tenTwo) + if tenTwo is None: # just encode + return x1s + F12, F21 = tenFlows + m1t = self.netSoftmetric(x1s, x2s, F12) * 2 * time + F1t = time * F12 + m2t = self.netSoftmetric(x2s, x1s, F21) * 2 * (1 - time) + F2t = (1 - time) * F21 + x1s, bmasks1 = warp_pyramid(x1s, m1t, F1t) + x2s, bmasks2 = warp_pyramid(x2s, m2t, F2t) + return list(zip(x1s, x2s)), bmasks1, bmasks2 + + +class SoftBinary(torch.nn.Module): + def __init__(self, cin, dilate_size=7) -> None: + super().__init__() + channel = 64 + reduction = 8 + self.conv1 = torch.nn.Sequential( + *[ + torch.nn.Conv2d(2, channel, dilate_size, 1, padding="same", bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(channel, channel, 3, 1, padding="same", bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(channel, channel, 1, 1, padding="same", bias=False), + ] + ) + self.att = torch.nn.Conv2d(cin * 2, channel, 3, 1, padding="same") + self.avg = torch.nn.AdaptiveAvgPool2d(1) + self.fc = torch.nn.Sequential( 
+ torch.nn.Linear(channel, channel // reduction, bias=False), + torch.nn.ReLU(inplace=True), + torch.nn.Linear(channel // reduction, channel, bias=False), + torch.nn.Sigmoid(), + ) + self.conv2 = torch.nn.Conv2d(channel, 1, 1, 1, padding="same", bias=False) + + def forward(self, bmask: list, feaL, feaR): + m_fea = self.conv1(torch.cat(bmask, dim=1)) + x = self.att(torch.cat([feaL, feaR], dim=1)) + b, c, _, _ = x.size() + x = self.avg(x).view(b, c) + x = self.fc(x).view(b, c, 1, 1) + x = m_fea * x.expand_as(x) + x = self.conv2(x) + return torch.sigmoid(x) + + +class AttentionMerge(torch.nn.Module): + def __init__(self, dilate_size=7, **kwargs): + super().__init__() + self.softbinary = torch.nn.ModuleDict() + channels = [35, 64, 96] + for i in range(2, -1, -1): + level = f"{i}" + c = channels[i] + self.softbinary[level] = SoftBinary(c, dilate_size, **kwargs) + + def forward(self, feaL, feaR, bmask1, bmask2): + outs = [] + soft_masks = [] + for i in range(2, -1, -1): + level = f"{i}" + sm = self.softbinary[level]([bmask1[i], bmask2[i]], feaL[i], feaR[i]) + soft_masks.append(sm) + x = feaL[i] * (1 - sm) + feaR[i] * sm + outs.append(x) + return outs, soft_masks + + +class Decoder(torch.nn.Module): + def __init__(self, cin): + super().__init__() + self.conv3 = Basic("conv-relu-conv", [cin[2], 128, cin[1]], True) + self.conv2 = Basic("conv-relu-conv", [cin[1] * 2, 256, cin[0]], True) + self.conv1 = Basic("conv-relu-conv", [cin[0] * 2, 96, 128], True) + self.tail = torch.nn.Conv2d(128 // 4, 3, 3, 1, padding="same") + self.up = torch.nn.UpsamplingBilinear2d(scale_factor=2) + logger.info( + f"Parameter of decoder: {sum(p.numel() for p in self.parameters())}" + ) + + def forward(self, xs): + lv1, lv2, lv3 = xs + lv3 = self.conv3(lv3) + lv3 = self.up(lv3) + lv2 = self.conv2(torch.cat([lv3, lv2], dim=1)) + lv2 = self.up(lv2) + lv1 = self.conv1(torch.cat([lv2, lv1], dim=1)) + lv1 = unsqueeze2d(lv1, factor=2) + return self.tail(lv1) + + +class Network(torch.torch.nn.Module): + def __init__(self, dilate_size=9, **kwargs): + super().__init__() + cond_c = [35, 64, 96] + self.featurePyramid = FeaturePyramid() + self.attentionMerge = AttentionMerge(dilate_size=dilate_size, **kwargs) + self.multiscaleFuse = MultiscaleFuse(cond_c) + self.generator = Decoder(cond_c) + + def get_cond(self, inps: list, time: float = 0.5): + tenOne, tenTwo, fflow, bflow = inps + with accelerate.Accelerator().autocast(): + feas, bmasks1, bmasks2 = self.featurePyramid( + tenOne, tenTwo, [fflow, bflow], time + ) + feaL = [feas[i][0] for i in range(3)] + feaR = [feas[i][1] for i in range(3)] + feas, smasks = self.attentionMerge(feaL, feaR, bmasks1, bmasks2) + feas = self.multiscaleFuse(feas[::-1]) # downscale by 2 + return feas, smasks + + def normalize(self, x, reverse=False): + # x in [0, 1] + if not reverse: + return x * 2 - 1 + else: + return (x + 1) / 2 + + def forward(self, inps=[], time=0.5, **kwargs): + img0, img1 = [self.normalize(x) for x in inps[:2]] + cond = [img0, img1] + inps[-2:] + + conds, smasks = self.get_cond(cond, time=time) + with accelerate.Accelerator().autocast(): + pred = self.generator(conds) + pred = self.normalize(pred, reverse=True) + return pred, smasks diff --git a/src/pervfiarches/generators/__init__.py b/src/pervfiarches/generators/__init__.py new file mode 100644 index 00000000..07b82213 --- /dev/null +++ b/src/pervfiarches/generators/__init__.py @@ -0,0 +1,40 @@ +def build_generator_arch(version): + if version.lower() == "v00": + from .PFlowVFI_V0 import Network + + model = Network(dilate_size=9) + + 
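+    # A rough usage sketch (hypothetical call site, not part of this patch): the
+    # fully-adaptive variant defined above takes the two input frames plus the
+    # forward and backward optical flows and returns the interpolated frame
+    # together with its soft masks, e.g.
+    #     model = build_generator_arch("ab_a")
+    #     pred, smasks = model([img0, img1, fflow, bflow], time=0.5)
+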
################## ABLATION ################## + elif version.lower() == "ab_b_n": + from .PFlowVFI_ablation import Network + + model = Network(dilate_size=7, mask_type="binary", noise=True) + elif version.lower() == "ab_b_nf": + from .PFlowVFI_ablation import Network + + model = Network(dilate_size=7, mask_type="binary", noise=False) + elif version.lower() == "ab_qb_nf": + from .PFlowVFI_ablation import Network + + model = Network(dilate_size=7, mask_type="quasi-binary", noise=False) + elif version.lower() == "ab_a": + from .PFlowVFI_adaptive import Network + + model = Network(dilate_size=7) + ################## ABLATION ################## + + elif version.lower() == "vb": + from .PFlowVFI_Vb import Network + + model = Network(9) + + elif version.lower() in ["v20_nll", "v20_laper"]: + from .PFlowVFI_V2 import Network_flow + + model = Network_flow(5) + elif version.lower() == "v2b": + from .PFlowVFI_V2 import Network_base + + model = Network_base(5) + + return model diff --git a/src/pervfiarches/generators/msfusion.py b/src/pervfiarches/generators/msfusion.py new file mode 100644 index 00000000..0af9d853 --- /dev/null +++ b/src/pervfiarches/generators/msfusion.py @@ -0,0 +1,515 @@ +"""Multi-Scale Feature Fusion Network (follow HRNet) +TODO: need to clear, but now it is usable. +""" +import os + +import torch +import torch.nn as nn +import logging as logger + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + + out = self.conv2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + + out = self.conv2(out) + out = self.relu(out) + + out = self.conv3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__( + self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True, + ): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels + ) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels + ) + self.fuse_layers = 
self._make_fuse_layers() + self.relu = nn.ReLU(True) + + def _check_branches( + self, num_branches, blocks, num_blocks, num_inchannels, num_channels + ): + if num_branches != len(num_blocks): + error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format( + num_branches, len(num_blocks) + ) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format( + num_branches, len(num_channels) + ) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format( + num_branches, len(num_inchannels) + ) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): + downsample = None + if ( + stride != 1 + or self.num_inchannels[branch_index] + != num_channels[branch_index] * block.expansion + ): + downsample = nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ) + + layers = [] + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index], + stride, + downsample, + ) + ) + self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block(self.num_inchannels[branch_index], num_channels[branch_index]) + ) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False, + ), + nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"), + ) + ) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False, + ) + ) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False, + ), + nn.ReLU(True), + ) + ) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck} +cfg = { + "MODEL": { + "EXTRA": { + "STAGE2": { + "NUM_MODULES": 2, + "NUM_BRANCHES": 2, + "BLOCK": "BASIC", + 
"NUM_BLOCKS": [2, 2], + "NUM_CHANNELS": [32, 64], + "FUSE_METHOD": "SUM", + }, + "STAGE3": { + "NUM_MODULES": 2, + "NUM_BRANCHES": 3, + "BLOCK": "BASIC", + "NUM_BLOCKS": [2, 2, 2], + "NUM_CHANNELS": [32, 64, 96], + "FUSE_METHOD": "SUM", + }, + "FINAL_CONV_KERNEL": 1, + "PRETRAINED_LAYERS": [], + }, + } +} + + +class MultiscaleFuse(nn.Module): + def __init__(self, cins=[35, 64, 96]): + self.inplanes = 64 + extra = cfg["MODEL"]["EXTRA"] + super().__init__() + + # stem net + self.stems = nn.ModuleList( + [nn.Sequential(conv3x3(cins[i], 64), nn.ReLU(True)) for i in range(3)] + ) + + # stage 1 + stage_1_block = Bottleneck + self.layer1 = self._make_layer(stage_1_block, 64, 4) + pre_stage_channels = [stage_1_block.expansion * 64] + + # stage 2 + self.stage2_cfg = extra["STAGE2"] + num_channels = self.stage2_cfg["NUM_CHANNELS"] + block = blocks_dict[self.stage2_cfg["BLOCK"]] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition1 = self._make_transition_layer(pre_stage_channels, num_channels) + self.extra1 = nn.Sequential( + conv3x3(64 + num_channels[-1], num_channels[-1]), nn.ReLU(True) + ) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels + ) + + self.stage3_cfg = extra["STAGE3"] + num_channels = self.stage3_cfg["NUM_CHANNELS"] + block = blocks_dict[self.stage3_cfg["BLOCK"]] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.extra2 = nn.Sequential( + conv3x3(64 + num_channels[-1], num_channels[-1]), nn.ReLU(True) + ) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels + ) + + self.final_layers = nn.ModuleList( + [ + nn.Conv2d( + in_channels=num_channels[i], + out_channels=cins[i], + kernel_size=3, + stride=2, + padding=1, + ) + for i in range(3) + ] + ) + + self.pretrained_layers = extra["PRETRAINED_LAYERS"] + + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False, + ), + nn.ReLU(inplace=True), + ) + ) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = ( + num_channels_cur_layer[i] + if j == i - num_branches_pre + else inchannels + ) + conv3x3s.append( + nn.Sequential( + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), + nn.ReLU(inplace=True), + ) + ) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, 
multi_scale_output=True): + num_modules = layer_config["NUM_MODULES"] + num_branches = layer_config["NUM_BRANCHES"] + num_blocks = layer_config["NUM_BLOCKS"] + num_channels = layer_config["NUM_CHANNELS"] + block = blocks_dict[layer_config["BLOCK"]] + fuse_method = layer_config["FUSE_METHOD"] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + ) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, xs): + x0 = self.stems[0](xs[0]) + x1 = self.stems[1](xs[1]) + x2 = self.stems[2](xs[2]) + + x = self.layer1(x0) + # print('STAGE1') + # print('after bottleneck:', x.shape) + x_list = [] + for i in range(self.stage2_cfg["NUM_BRANCHES"]): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + # print('after transition:', [i.shape for i in x_list]) + + # print('concatenating') + x_list[-1] = self.extra1(torch.cat([x1, x_list[-1]], dim=1)) + # print('after concatenating:', [i.shape for i in x_list]) + + # print('STAGE2') + y_list = self.stage2(x_list) + # print('after fusion:', [i.shape for i in y_list]) + x_list = [] + for i in range(self.stage3_cfg["NUM_BRANCHES"]): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + # print('after transition:', [i.shape for i in x_list]) + + # print('concatenating') + x_list[-1] = self.extra2(torch.cat([x2, x_list[-1]], dim=1)) + # print('after concatenating:', [i.shape for i in x_list]) + + # print('STAGE3') + y_list = self.stage3(x_list) + # print('after fusion:', [i.shape for i in y_list]) + + x_list = [] + for i in range(len(y_list)): + x_list.append(self.final_layers[i](y_list[i])) + + return x_list + + def init_weights(self, pretrained=""): + logger.info("=> init weights from normal distribution") + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.normal_(m.weight, std=0.001) + for name, _ in m.named_parameters(): + if name in ["bias"]: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + nn.init.normal_(m.weight, std=0.001) + for name, _ in m.named_parameters(): + if name in ["bias"]: + nn.init.constant_(m.bias, 0) + + if os.path.isfile(pretrained): + pretrained_state_dict = torch.load(pretrained) + logger.info("=> loading pretrained model {}".format(pretrained)) + + need_init_state_dict = {} + for name, m in pretrained_state_dict.items(): + if ( + name.split(".")[0] in self.pretrained_layers + or self.pretrained_layers[0] == "*" + ): + need_init_state_dict[name] = m + self.load_state_dict(need_init_state_dict, strict=False) + elif pretrained: + logger.error("=> please download pre-trained models first!") + raise ValueError("{} is not exist!".format(pretrained)) diff --git a/src/pervfiarches/generators/normalizing_flow.py b/src/pervfiarches/generators/normalizing_flow.py new file mode 100644 index 00000000..b1bf4cda --- /dev/null +++ b/src/pervfiarches/generators/normalizing_flow.py @@ -0,0 +1,461 @@ +"""Basic layers 
and funcs for Normalizing Flow Network +""" +from math import log, pi + +import numpy as np +import torch + +from scipy import linalg as la +from torch import nn +from torch.nn import functional as F + +from . import thops + + +class ActNorm(nn.Module): + def __init__(self, in_channel): + super().__init__() + + self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) + self.logs = nn.Parameter(torch.ones(1, in_channel, 1, 1)) + + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + logs = torch.log(1 / (std + 1e-6) + 1e-6) + self.logs.data.copy_(logs.data) + + def forward(self, input, reverse=False): + if self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + if reverse: + return input * (torch.exp(-self.logs)) - self.loc + else: + logdet = thops.pixels(input) * torch.sum(self.logs) + return (torch.exp(self.logs) + 1e-8) * (input + self.loc), logdet + + +class InvConv2d(nn.Module): + def __init__(self, in_channel): + super().__init__() + + weight = torch.randn(in_channel, in_channel) + q, _ = torch.qr(weight) + weight = q[:, :, None, None] + self.weight = nn.Parameter(weight) + + def forward(self, input, reverse=False): + _, _, height, width = input.shape + + if reverse: + return F.conv2d(input, self.weight.squeeze().inverse()[:, :, None, None]) + + else: + logdet = ( + thops.pixels(input) + * torch.slogdet(self.weight.squeeze().double())[1].float() + ) + return F.conv2d(input, self.weight), logdet + + +class InvConv2dLU(nn.Module): + """invertible conv2d with LU decompose""" + + def __init__(self, in_channel): + super().__init__() + + weight = np.random.randn(in_channel, in_channel) + q, _ = la.qr(weight) + w_p, w_l, w_u = la.lu(q.astype(np.float32)) + w_s = np.diag(w_u) + w_u = np.triu(w_u, 1) + u_mask = np.triu(np.ones_like(w_u), 1) + l_mask = u_mask.T + + w_p = torch.from_numpy(w_p) + w_l = torch.from_numpy(w_l) + w_s = torch.from_numpy(w_s) + w_u = torch.from_numpy(w_u) + + self.register_buffer("w_p", w_p) + self.register_buffer("u_mask", torch.from_numpy(u_mask)) + self.register_buffer("l_mask", torch.from_numpy(l_mask)) + self.register_buffer("s_sign", torch.sign(w_s)) + self.register_buffer("l_eye", torch.eye(l_mask.shape[0])) + self.w_l = nn.Parameter(w_l) + self.w_s = nn.Parameter(torch.log(torch.abs(w_s) + 1e-6)) + self.w_u = nn.Parameter(w_u) + + def forward(self, input, reverse=False): + weight = self.calc_weight() + + if reverse: + return F.conv2d(input, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3)) + else: + logdet = thops.pixels(input) * torch.sum(self.w_s) + return F.conv2d(input, weight), logdet + + def calc_weight(self): + weight = ( + self.w_p + @ (self.w_l * self.l_mask + self.l_eye) + @ ((self.w_u * self.u_mask) + torch.diag(self.s_sign * torch.exp(self.w_s))) + ) + return weight.unsqueeze(2).unsqueeze(3) + + +class ZeroConv2d(nn.Module): + """The 3x3 convolution in which weight and bias are initialized with zero. + The output is then scaled with a positive learnable param. 
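+    Used here as the last layer of the conditional coupling networks and of the
+    block priors, so each coupling step is initialized as an identity mapping.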
+ """ + + def __init__(self, in_channel, out_channel): + super().__init__() + + self.conv = nn.Conv2d(in_channel, out_channel, 3, padding=1) + self.conv.weight.data.zero_() + self.conv.bias.data.zero_() + self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + + def forward(self, input): + out = self.conv(input) + out = out * torch.exp(self.scale * 3) + return out + + +class condAffineCouplingBN(nn.Module): + """The conditional affine coupling layer with BN layer.""" + + def __init__(self, cin, ccond): + super().__init__() + + class condAffineNet(nn.Module): + def __init__(self, cin, ccond, cout) -> None: + super().__init__() + self.tail = nn.Sequential( + nn.Conv2d(cin + ccond, 64, 3, 1, 1), + nn.ReLU(), + nn.Conv2d(64, 64, 1, 1, 0), + nn.ReLU(), + ZeroConv2d(64, cout), + ) + + def forward(self, x, cond): + return self.tail(torch.cat([x, cond], dim=1)) + + self.affine = condAffineNet(cin // 2, ccond, cin) + self.bn = nn.BatchNorm2d(cin // 2, affine=False) + self.scale = nn.Parameter(torch.zeros(1), requires_grad=True) + self.scale_shift = nn.Parameter(torch.zeros(1), requires_grad=True) + + def forward(self, input, cond, reverse=False): + if not reverse: + return self.encode(input, cond) + else: + return self.decode(input, cond) + + def encode(self, x, cond): + # split into two folds, 'off' is to generate shift and log_rescale + # FIXME BN layer is not revertable yet. + on, off = x.chunk(2, 1) + shift, log_rescale = self.affine(off, cond).chunk(2, 1) + log_rescale = self.scale * torch.tanh(log_rescale) + self.scale_shift + # affine + on = on * torch.exp(log_rescale) + shift + logdet = thops.sum(log_rescale, dim=[1, 2, 3]) + # bn + if self.training: + mean = torch.mean(on, dim=(1, 2, 3)) + var = torch.mean((on - mean) ** 2, dim=(1, 2, 3)) + else: + var = self.bn.running_var + on = self.bn(on) + # mean, var = self.bn.running_mean, self.bn.running_var + print("encode mean: ", torch.mean(mean), "var: ", torch.mean(var)) + logdet = logdet - 0.5 * torch.log(torch.mean(var) + 1e-5) + + output = torch.cat([on, off], 1) + return output, logdet + + def decode(self, x, cond): + on, off = x.chunk(2, 1) + shift, log_rescale = self.affine(off, cond).chunk(2, 1) + log_rescale = self.scale * torch.tanh(log_rescale) + self.scale_shift + + mean, var = self.bn.running_mean, self.bn.running_var + print("decode mean: ", torch.mean(mean), "var: ", torch.mean(var)) + mean = mean.reshape(-1, 1, 1, 1).transpose(0, 1) + var = var.reshape(-1, 1, 1, 1).transpose(0, 1) + on = on * torch.exp(0.5 * torch.log(var + 1e-5)) + mean + on = (on - shift) * torch.exp(-log_rescale) + x = torch.cat([on, off], 1) + return x + + +class condAffineCoupling(nn.Module): + """The conditional affine coupling layer""" + + def __init__(self, cin, ccond): + class condAffineNet(nn.Module): + def __init__(self, cin, ccond, cout) -> None: + super().__init__() + self.tail = nn.Sequential( + nn.Conv2d(cin + ccond, 64, 3, 1, 1), + nn.ReLU(), + nn.Conv2d(64, 64, 1, 1, 0), + nn.ReLU(), + ZeroConv2d(64, cout), + ) + + def forward(self, x, cond): + return self.tail(torch.cat([x, cond], dim=1)) + + super().__init__() + self.affine = condAffineNet(cin // 2, ccond, cin) + self.scale = nn.Parameter(torch.zeros(1), requires_grad=True) + self.scale_shift = nn.Parameter(torch.zeros(1), requires_grad=True) + + def forward(self, input, cond, reverse=False): + if not reverse: + return self.encode(input, cond) + else: + return self.decode(input, cond) + + def encode(self, x, cond): + on, off = x.chunk(2, 1) + shift, log_rescale = self.affine(off, 
cond).chunk(2, 1) + log_rescale = self.scale * torch.tanh(log_rescale) + self.scale_shift + on = on * torch.exp(log_rescale) + shift + output = torch.cat([on, off], 1) + logdet = thops.sum(log_rescale, dim=[1, 2, 3]) + + return output, logdet + + def decode(self, x, cond): + on, off = x.chunk(2, 1) + shift, log_rescale = self.affine(off, cond).chunk(2, 1) + log_rescale = self.scale * torch.tanh(log_rescale) + self.scale_shift + on = (on - shift) * torch.exp(-log_rescale) + x = torch.cat([on, off], 1) + return x + + +class Flow(nn.Module): + """Flow step, contains actnorm -> invconv -> condAffineCouple""" + + def __init__(self, in_channel, cond_channel, with_bn=False, train_1x1=True): + super().__init__() + + self.actnorm = ActNorm(in_channel) + + self.invconv = InvConv2d(in_channel) + if not train_1x1: + for p in self.invconv.parameters(): + p.requires_grad = False # no need to train + + if with_bn: + self.coupling = condAffineCouplingBN(in_channel, cond_channel) + else: + self.coupling = condAffineCoupling(in_channel, cond_channel) + + def forward(self, input, cond, reverse=False): + if not reverse: + x, logdet = self.actnorm(input) + x, det1 = self.invconv(x) + out, det2 = self.coupling(x, cond) + logdet = logdet + det1 + det2 + return out, logdet + else: + x = self.coupling(input, cond, reverse) + x = self.invconv(x, reverse) + out = self.actnorm(x, reverse) + return out + + +class Block(nn.Module): + """Each block contains: squeeze -> flowstep ... flowstep -> split""" + + def __init__( + self, K, in_channel, cond_channel, split: bool, with_bn: bool, train_1x1: bool + ): + super().__init__() + + self.K = K # number of flow steps + + squeeze_dim = in_channel * 4 + self.split = split + + # layers + self.actnorm = ActNorm(squeeze_dim) + self.invconv = InvConv2d(squeeze_dim) + self.flows = nn.ModuleList( + [Flow(squeeze_dim, cond_channel, with_bn, train_1x1) for _ in range(self.K)] + ) + if not train_1x1: + for p in self.invconv.parameters(): + p.requires_grad = False # no need to train + + if self.split: + self.prior = ZeroConv2d(in_channel * 2, in_channel * 4) + else: + self.prior = ZeroConv2d(in_channel * 4, in_channel * 8) + + def forward(self, input, cond, eps=None, reverse=False): + if not reverse: + out, logdet, log_p, z_new = self.encode(input, cond) + return out, logdet, log_p, z_new + else: + out = self.decode(input, cond, eps) + return out + + def encode(self, input, cond): + b_size = input.shape[0] + out = squeeze2d(input, 2) + out, logdet = self.actnorm(out) + out, det = self.invconv(out) + logdet += det + + for flow in self.flows: + out, det = flow(out, cond) + logdet = logdet + det + if self.split: + out, z_new = out.chunk(2, 1) + mean, log_sd = self.prior(out).chunk(2, 1) + log_p = gaussian_log_p(z_new, mean, log_sd) + log_p = log_p.view(b_size, -1).sum(1) + else: + zero = torch.zeros_like(out) + mean, log_sd = self.prior(zero).chunk(2, 1) + log_p = gaussian_log_p(out, mean, log_sd) + log_p = log_p.view(b_size, -1).sum(1) + z_new = out + + return out, logdet, log_p, z_new + + def decode(self, output, cond, eps=None): + # eps: noise + input = output + + if self.split: + mean, log_sd = self.prior(input).chunk(2, 1) + z = gaussian_sample(eps, mean, log_sd) + input = torch.cat([output, z], 1) + else: + zero = torch.zeros_like(input) + # zero = F.pad(zero, [1, 1, 1, 1], value=1) + mean, log_sd = self.prior(zero).chunk(2, 1) + z = gaussian_sample(eps, mean, log_sd) + input = z + + for flow in self.flows[::-1]: + input = flow(input, cond, reverse=True) + + input = 
self.invconv(input, reverse=True) + input = self.actnorm(input, reverse=True) + + unsqueezed = unsqueeze2d(input) + return unsqueezed + + +class CondFlowNet(nn.Module): + """Conditional Normalizing Flow Network""" + + def __init__(self, cins: list, with_bn: bool, train_1x1: bool, K=4, **kwargs): + super().__init__() + self.L = 3 # block number + # three blocks at three scales, each has 4,4,4 flowsteps. + self.blocks = nn.ModuleList() + conf = dict(with_bn=with_bn, train_1x1=train_1x1) + self.blocks.append(Block(K, 3, cins[0], split=True, **conf)) + self.blocks.append(Block(K, 3 * 2, cins[1], split=True, **conf)) + self.blocks.append(Block(K, 3 * 4, cins[2], split=False, **conf)) + # logger.info( + # f"Parameter of condflownet: {sum(p.numel() for p in self.parameters())}" + # ) + + def forward(self, input, conds: list, reverse=False): + if not reverse: + log_p_sum, logdet, z_outs = self.encode(input, conds) + return log_p_sum, logdet, z_outs + else: + z_list = input + out = self.decode(z_list, conds[::-1]) + return out + + def encode(self, input: torch.Tensor, conds: list): + log_p = 0 + logdet = 0 + z_outs = [] + + for i, block in enumerate(self.blocks): + input, det, logp, z_new = block(input, conds[i]) + z_outs.append(z_new) + logdet = logdet + det + log_p = log_p + logp + + return log_p, logdet, z_outs + + def decode(self, z_list: list, conds: list): + for i, block in enumerate(self.blocks[::-1]): + if i == 0: + input = block(z_list[-1], conds[0], z_list[-1], reverse=True) + else: + input = block(input, conds[i], z_list[-(i + 1)], reverse=True) + + return input + + +def gaussian_log_p(x, mean, log_sd): + return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 * torch.exp(-2 * log_sd) + + +def gaussian_sample(eps, mean, log_sd): + return mean + torch.exp(log_sd) * eps + + +def squeeze2d(input, factor=2): + assert factor >= 1 and isinstance(factor, int) + if factor == 1: + return input + B, C, H, W = input.shape + factor2 = factor**2 + assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor)) + + x = input.view(B, C, H // factor, factor, W // factor, factor) + x = x.permute(0, 1, 3, 5, 2, 4).contiguous() + x = x.view(B, C * factor2, H // factor, W // factor) + return x + + +def unsqueeze2d(input, factor=2): + assert factor >= 1 and isinstance(factor, int) + if factor == 1: + return input + B, C, H, W = input.shape + factor2 = factor**2 + assert C % (factor2) == 0, "{}".format(C) + + x = input.view(B, C // factor2, factor, factor, H, W) + x = x.permute(0, 1, 4, 2, 5, 3).contiguous() + x = x.view(B, C // (factor2), H * factor, W * factor) + return x diff --git a/src/pervfiarches/generators/softsplatnet/__init__.py b/src/pervfiarches/generators/softsplatnet/__init__.py new file mode 100644 index 00000000..8fe4b3c5 --- /dev/null +++ b/src/pervfiarches/generators/softsplatnet/__init__.py @@ -0,0 +1,794 @@ +#!/usr/bin/env python + +import torch +import torch.nn.functional as TF + +from . import correlation # the custom cost volume layer +from . 
import softsplat # the custom softmax splatting layer + +backwarp_tenGrid = {} + + +def backwarp(tenIn, tenFlow): + if str(tenFlow.shape) not in backwarp_tenGrid: + tenHor = ( + torch.linspace( + start=-1.0, + end=1.0, + steps=tenFlow.shape[3], + dtype=tenFlow.dtype, + device=tenFlow.device, + ) + .view(1, 1, 1, -1) + .repeat(1, 1, tenFlow.shape[2], 1) + ) + tenVer = ( + torch.linspace( + start=-1.0, + end=1.0, + steps=tenFlow.shape[2], + dtype=tenFlow.dtype, + device=tenFlow.device, + ) + .view(1, 1, -1, 1) + .repeat(1, 1, 1, tenFlow.shape[3]) + ) + + backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda() + # end + + tenFlow = torch.cat( + [ + tenFlow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + return torch.nn.functional.grid_sample( + input=tenIn, + grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ) + + +def binary(x, threshold): + ones = torch.ones_like(x, device=x.device) + zeros = torch.zeros_like(x, device=x.device) + return torch.where(x <= threshold, ones, zeros) + + +def calc_hole(x, flow): + hole = binary( + softsplat.softsplat( + tenIn=torch.ones_like(x, device=x.device), + tenFlow=flow, + tenMetric=None, + strMode="avg", + ), + 0.5, + ) + return hole + + +class Decoder(torch.nn.Module): + def __init__(self, intChannels): + super().__init__() + + self.netMain = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=intChannels, + out_channels=128, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=128, + out_channels=128, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=128, + out_channels=96, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=96, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=64, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=32, + out_channels=2, + kernel_size=3, + stride=1, + padding=1, + ), + ) + + # end + + def forward(self, tenOne, tenTwo, objPrevious): + intWidth = tenOne.shape[3] and tenTwo.shape[3] + intHeight = tenOne.shape[2] and tenTwo.shape[2] + + tenMain = None + + if objPrevious is None: + tenVolume = correlation.FunctionCorrelation(tenOne=tenOne, tenTwo=tenTwo) + + tenMain = torch.cat([tenOne, tenVolume], 1) + + elif objPrevious is not None: + tenForward = ( + torch.nn.functional.interpolate( + input=objPrevious["tenForward"], + size=(intHeight, intWidth), + mode="bilinear", + ) + / float(objPrevious["tenForward"].shape[3]) + * float(intWidth) + ) + + tenVolume = correlation.FunctionCorrelation( + tenOne=tenOne, tenTwo=backwarp(tenTwo, tenForward) + ) + + tenMain = torch.cat([tenOne, tenVolume, tenForward], 1) + + # end + + return {"tenForward": self.netMain(tenMain)} + + # end + + +class Extractor(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netFirst = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=3, + out_channels=16, + kernel_size=3, + stride=2, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + 
in_channels=16, + out_channels=16, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + + self.netSecond = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=16, + out_channels=32, + kernel_size=3, + stride=2, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + + self.netThird = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=2, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + + self.netFourth = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=64, + out_channels=96, + kernel_size=3, + stride=2, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=96, + out_channels=96, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + + self.netFifth = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=96, + out_channels=128, + kernel_size=3, + stride=2, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=128, + out_channels=128, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + + self.netSixth = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=128, + out_channels=192, + kernel_size=3, + stride=2, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d( + in_channels=192, + out_channels=192, + kernel_size=3, + stride=1, + padding=1, + ), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + + # end + + def forward(self, tenInput): + tenFirst = self.netFirst(tenInput) + tenSecond = self.netSecond(tenFirst) + tenThird = self.netThird(tenSecond) + tenFourth = self.netFourth(tenThird) + tenFifth = self.netFifth(tenFourth) + tenSixth = self.netSixth(tenFifth) + + return [tenFirst, tenSecond, tenThird, tenFourth, tenFifth, tenSixth] + + # end + + +class Basic(torch.nn.Module): + def __init__(self, strType, intChannels, boolSkip): + super().__init__() + + if strType == "relu-conv-relu-conv": + self.netMain = torch.nn.Sequential( + torch.nn.PReLU(num_parameters=intChannels[0], init=0.25), + torch.nn.Conv2d( + in_channels=intChannels[0], + out_channels=intChannels[1], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + torch.nn.PReLU(num_parameters=intChannels[1], init=0.25), + torch.nn.Conv2d( + in_channels=intChannels[1], + out_channels=intChannels[2], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + ) + + elif strType == "conv-relu-conv": + self.netMain = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=intChannels[0], + out_channels=intChannels[1], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + torch.nn.PReLU(num_parameters=intChannels[1], init=0.25), + torch.nn.Conv2d( + in_channels=intChannels[1], + out_channels=intChannels[2], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + ) + + # end + + self.boolSkip = boolSkip + + if boolSkip == True: + if intChannels[0] == intChannels[2]: + self.netShortcut = None + + elif intChannels[0] != intChannels[2]: + self.netShortcut = 
torch.nn.Conv2d( + in_channels=intChannels[0], + out_channels=intChannels[2], + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + + def forward(self, tenInput): + if self.boolSkip == False: + return self.netMain(tenInput) + # end + + if self.netShortcut is None: + return self.netMain(tenInput) + tenInput + + elif self.netShortcut is not None: + return self.netMain(tenInput) + self.netShortcut(tenInput) + + +class Downsample(torch.nn.Module): + def __init__(self, intChannels): + super().__init__() + + self.netMain = torch.nn.Sequential( + torch.nn.PReLU(num_parameters=intChannels[0], init=0.25), + torch.nn.Conv2d( + in_channels=intChannels[0], + out_channels=intChannels[1], + kernel_size=3, + stride=2, + padding=1, + bias=False, + ), + torch.nn.PReLU(num_parameters=intChannels[1], init=0.25), + torch.nn.Conv2d( + in_channels=intChannels[1], + out_channels=intChannels[2], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + ) + + # end + + def forward(self, tenInput): + return self.netMain(tenInput) + + +class Upsample(torch.nn.Module): + def __init__(self, intChannels): + super().__init__() + + self.netMain = torch.nn.Sequential( + torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + torch.nn.PReLU(num_parameters=intChannels[0], init=0.25), + torch.nn.Conv2d( + in_channels=intChannels[0], + out_channels=intChannels[1], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + torch.nn.PReLU(num_parameters=intChannels[1], init=0.25), + torch.nn.Conv2d( + in_channels=intChannels[1], + out_channels=intChannels[2], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + ) + + # end + + def forward(self, tenInput): + return self.netMain(tenInput) + + +class Encode(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netOne = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=3, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + torch.nn.PReLU(num_parameters=32, init=0.25), + torch.nn.Conv2d( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + torch.nn.PReLU(num_parameters=32, init=0.25), + ) + + self.netTwo = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ), + torch.nn.PReLU(num_parameters=64, init=0.25), + torch.nn.Conv2d( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + torch.nn.PReLU(num_parameters=64, init=0.25), + ) + + self.netThr = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels=64, + out_channels=96, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ), + torch.nn.PReLU(num_parameters=96, init=0.25), + torch.nn.Conv2d( + in_channels=96, + out_channels=96, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + torch.nn.PReLU(num_parameters=96, init=0.25), + ) + + # end + + def forward(self, tenInput): + tenOutput = [] + + tenOutput.append(self.netOne(tenInput)) + tenOutput.append(self.netTwo(tenOutput[-1])) + tenOutput.append(self.netThr(tenOutput[-1])) + + return [torch.cat([tenInput, tenOutput[0]], 1)] + tenOutput[1:] + + +class Softmetric(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netInput = torch.nn.Conv2d( + in_channels=3, + out_channels=12, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.netError = torch.nn.Conv2d( + in_channels=1, + out_channels=4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + + for intRow, 
intFeatures in [(0, 16), (1, 32), (2, 64), (3, 96)]: + self.add_module( + str(intRow) + "x0" + " - " + str(intRow) + "x1", + Basic( + "relu-conv-relu-conv", + [intFeatures, intFeatures, intFeatures], + True, + ), + ) + # end + + for intCol in [0]: + self.add_module( + "0x" + str(intCol) + " - " + "1x" + str(intCol), + Downsample([16, 32, 32]), + ) + self.add_module( + "1x" + str(intCol) + " - " + "2x" + str(intCol), + Downsample([32, 64, 64]), + ) + self.add_module( + "2x" + str(intCol) + " - " + "3x" + str(intCol), + Downsample([64, 96, 96]), + ) + # end + + for intCol in [1]: + self.add_module( + "3x" + str(intCol) + " - " + "2x" + str(intCol), + Upsample([96, 64, 64]), + ) + self.add_module( + "2x" + str(intCol) + " - " + "1x" + str(intCol), + Upsample([64, 32, 32]), + ) + self.add_module( + "1x" + str(intCol) + " - " + "0x" + str(intCol), + Upsample([32, 16, 16]), + ) + # end + + self.netOutput = Basic("conv-relu-conv", [16, 16, 1], True) + + # end + + def forward(self, tenEncone, tenEnctwo, tenFlow): + tenColumn = [None, None, None, None] + + tenColumn[0] = torch.cat( + [ + self.netInput(tenEncone[0][:, 0:3, :, :]), + self.netError( + torch.nn.functional.l1_loss( + input=tenEncone[0], + target=backwarp(tenEnctwo[0], tenFlow), + reduction="none", + ).mean([1], True) + ), + ], + 1, + ) + tenColumn[1] = self._modules["0x0 - 1x0"](tenColumn[0]) + tenColumn[2] = self._modules["1x0 - 2x0"](tenColumn[1]) + tenColumn[3] = self._modules["2x0 - 3x0"](tenColumn[2]) + + intColumn = 1 + for intRow in range(len(tenColumn) - 1, -1, -1): + tenColumn[intRow] = self._modules[ + str(intRow) + + "x" + + str(intColumn - 1) + + " - " + + str(intRow) + + "x" + + str(intColumn) + ](tenColumn[intRow]) + if intRow != len(tenColumn) - 1: + tenUp = self._modules[ + str(intRow + 1) + + "x" + + str(intColumn) + + " - " + + str(intRow) + + "x" + + str(intColumn) + ](tenColumn[intRow + 1]) + + if tenUp.shape[2] != tenColumn[intRow].shape[2]: + tenUp = torch.nn.functional.pad( + input=tenUp, + pad=[0, 0, 0, -1], + mode="constant", + value=0.0, + ) + if tenUp.shape[3] != tenColumn[intRow].shape[3]: + tenUp = torch.nn.functional.pad( + input=tenUp, + pad=[0, -1, 0, 0], + mode="constant", + value=0.0, + ) + + tenColumn[intRow] = tenColumn[intRow] + tenUp + # end + # end + + return self.netOutput(tenColumn[0]) + + +class Warp(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + + def forward( + self, + tenEncone, + tenEnctwo, + tenMetricone, + tenMetrictwo, + tenForward, + tenBackward, + ): + tenOutput = [] + tenMasks = [] + + for intLevel in range(3): + tenOne = tenEncone[intLevel] + tenTwo = tenEnctwo[intLevel] + H, W = tenOne.shape[-2:] + h, w = tenForward.shape[-2:] + if intLevel != 0: + tenMetricone = TF.interpolate( + tenMetricone, size=(H, W), mode="bilinear" + ) + tenMetrictwo = TF.interpolate( + tenMetrictwo, size=(H, W), mode="bilinear" + ) + + tenForward = TF.interpolate( + tenForward, size=(H, W), mode="bilinear" + ) * (H / h) + tenBackward = TF.interpolate( + tenBackward, size=(H, W), mode="bilinear" + ) * (H / h) + + tenOutput.append( + [ + softsplat.softsplat( + tenIn=torch.cat([tenOne, tenMetricone], 1), + tenFlow=tenForward, + tenMetric=tenMetricone.neg().clip(-20.0, 20.0), + strMode="soft", + ), + softsplat.softsplat( + tenIn=torch.cat([tenTwo, tenMetrictwo], 1), + tenFlow=tenBackward, + tenMetric=tenMetrictwo.neg().clip(-20.0, 20.0), + strMode="soft", + ), + ] + ) + tenMasks.append( + [ + calc_hole(tenMetricone, tenForward), + calc_hole(tenMetrictwo, tenBackward), + ] + ) + 
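+            # Downscale the metrics and flows by 2 for an extra half-resolution set
+            # (flow magnitudes are rescaled to the new size) before computing the
+            # additional hole masks below.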
tenMetricone = TF.interpolate( + tenMetricone, size=(H // 2, W // 2), mode="bilinear" + ) + tenMetrictwo = TF.interpolate( + tenMetrictwo, size=(H // 2, W // 2), mode="bilinear" + ) + + tenForward = ( + TF.interpolate(tenForward, size=(H // 2, W // 2), mode="bilinear") + * (H // 2) + / h + ) + tenBackward = ( + TF.interpolate(tenBackward, size=(H // 2, W // 2), mode="bilinear") + * (H // 2) + / h + ) + tenMasks.append( + [ + calc_hole(tenMetricone, tenForward), + calc_hole(tenMetrictwo, tenBackward), + ] + ) + + return tenOutput, tenMasks + + +class Flow(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netExtractor = Extractor() + + self.netFirst = Decoder(16 + 81 + 2) + self.netSecond = Decoder(32 + 81 + 2) + self.netThird = Decoder(64 + 81 + 2) + self.netFourth = Decoder(96 + 81 + 2) + self.netFifth = Decoder(128 + 81 + 2) + self.netSixth = Decoder(192 + 81) + + # end + + def forward(self, tenOne, tenTwo): + intWidth = tenOne.shape[3] and tenTwo.shape[3] + intHeight = tenOne.shape[2] and tenTwo.shape[2] + + tenOne = self.netExtractor(tenOne) + tenTwo = self.netExtractor(tenTwo) + + objForward = None + objBackward = None + + objForward = self.netSixth(tenOne[-1], tenTwo[-1], objForward) + objBackward = self.netSixth(tenTwo[-1], tenOne[-1], objBackward) + + objForward = self.netFifth(tenOne[-2], tenTwo[-2], objForward) + objBackward = self.netFifth(tenTwo[-2], tenOne[-2], objBackward) + + objForward = self.netFourth(tenOne[-3], tenTwo[-3], objForward) + objBackward = self.netFourth(tenTwo[-3], tenOne[-3], objBackward) + + objForward = self.netThird(tenOne[-4], tenTwo[-4], objForward) + objBackward = self.netThird(tenTwo[-4], tenOne[-4], objBackward) + + objForward = self.netSecond(tenOne[-5], tenTwo[-5], objForward) + objBackward = self.netSecond(tenTwo[-5], tenOne[-5], objBackward) + + objForward = self.netFirst(tenOne[-6], tenTwo[-6], objForward) + objBackward = self.netFirst(tenTwo[-6], tenOne[-6], objBackward) + + return { + "tenForward": torch.nn.functional.interpolate( + input=objForward["tenForward"], + size=(intHeight, intWidth), + mode="bilinear", + align_corners=False, + ) + * (float(intWidth) / float(objForward["tenForward"].shape[3])), + "tenBackward": torch.nn.functional.interpolate( + input=objBackward["tenForward"], + size=(intHeight, intWidth), + mode="bilinear", + align_corners=False, + ) + * (float(intWidth) / float(objBackward["tenForward"].shape[3])), + } diff --git a/src/pervfiarches/generators/softsplatnet/correlation.py b/src/pervfiarches/generators/softsplatnet/correlation.py new file mode 100644 index 00000000..688c2e2a --- /dev/null +++ b/src/pervfiarches/generators/softsplatnet/correlation.py @@ -0,0 +1,462 @@ +#!/usr/bin/env python + +import cupy +import re +import torch + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + 
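+# The rearrange kernel above copies each NCHW input into a channels-last buffer
+# padded by 4 pixels on every side; the kernel below then builds an 81-channel
+# cost volume by correlating every feature patch with its +/-4 px neighbourhood.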
+kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradOne = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradOne( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradOne, + float* gradTwo + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradOne); // channels + int l = (intIndex / SIZE_1(gradOne)) % SIZE_3(gradOne) + 4; // w-pos + int m = (intIndex / SIZE_1(gradOne) / SIZE_3(gradOne)) % SIZE_2(gradOne) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. 
+ const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradOne); + const int bot0index = ((n * SIZE_2(gradOne)) + (m-4)) * SIZE_3(gradOne) + (l-4); + gradOne[bot0index + intSample*SIZE_1(gradOne)*SIZE_2(gradOne)*SIZE_3(gradOne)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradTwo = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradTwo( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradOne, + float* gradTwo + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradTwo); // channels + int l = (intIndex / SIZE_1(gradTwo)) % SIZE_3(gradTwo) + 4; // w-pos + int m = (intIndex / SIZE_1(gradTwo) / SIZE_3(gradTwo)) % SIZE_2(gradTwo) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. 
+ const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradTwo); + const int bot1index = ((n * SIZE_2(gradTwo)) + (m-4)) * SIZE_3(gradTwo) + (l-4); + gradTwo[bot1index + intSample*SIZE_1(gradTwo)*SIZE_2(gradTwo)*SIZE_3(gradTwo)] = sum / (float)sumelems; + } } +''' + + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace( + objMatch.group(), + str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == + False else intSizes[intArg].item())) + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ + '((' + + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + + ')*' + + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == + False else intStrides[intArg].item()) + ')' + for intArg in range(intArgs) + ] + + strKernel = strKernel.replace( + objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel + + +# end + + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +# end + + +class _FunctionCorrelation(torch.autograd.Function): + + @staticmethod + def forward(self, one, two): + rbot0 = one.new_zeros( + [one.shape[0], one.shape[2] + 8, one.shape[3] + 8, one.shape[1]]) + rbot1 = one.new_zeros( + [one.shape[0], one.shape[2] + 8, one.shape[3] + 8, one.shape[1]]) + + one = one.contiguous() + assert (one.is_cuda == True) + two = two.contiguous() + assert (two.is_cuda == True) + + output = one.new_zeros([one.shape[0], 81, one.shape[2], 
one.shape[3]]) + + if one.is_cuda == True: + n = one.shape[2] * one.shape[3] + cupy_launch( + 'kernel_Correlation_rearrange', + cupy_kernel('kernel_Correlation_rearrange', { + 'input': one, + 'output': rbot0 + }))(grid=tuple( + [int((n + 16 - 1) / 16), one.shape[1], one.shape[0]]), + block=tuple([16, 1, 1]), + args=[cupy.int32(n), + one.data_ptr(), + rbot0.data_ptr()]) + + n = two.shape[2] * two.shape[3] + cupy_launch( + 'kernel_Correlation_rearrange', + cupy_kernel('kernel_Correlation_rearrange', { + 'input': two, + 'output': rbot1 + }))(grid=tuple( + [int((n + 16 - 1) / 16), two.shape[1], two.shape[0]]), + block=tuple([16, 1, 1]), + args=[cupy.int32(n), + two.data_ptr(), + rbot1.data_ptr()]) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch( + 'kernel_Correlation_updateOutput', + cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))(grid=tuple( + [output.shape[3], output.shape[2], output.shape[0]]), + block=tuple([32, 1, 1]), + shared_mem=one.shape[1] * 4, + args=[ + cupy.int32(n), + rbot0.data_ptr(), + rbot1.data_ptr(), + output.data_ptr() + ]) + + elif one.is_cuda == False: + raise NotImplementedError() + + # end + + self.save_for_backward(one, two, rbot0, rbot1) + + return output + + # end + + @staticmethod + def backward(self, gradOutput): + one, two, rbot0, rbot1 = self.saved_tensors + + gradOutput = gradOutput.contiguous() + assert (gradOutput.is_cuda == True) + + gradOne = one.new_zeros([ + one.shape[0], one.shape[1], one.shape[2], one.shape[3] + ]) if self.needs_input_grad[0] == True else None + gradTwo = one.new_zeros([ + one.shape[0], one.shape[1], one.shape[2], one.shape[3] + ]) if self.needs_input_grad[1] == True else None + + if one.is_cuda == True: + if gradOne is not None: + for intSample in range(one.shape[0]): + n = one.shape[1] * one.shape[2] * one.shape[3] + cupy_launch( + 'kernel_Correlation_updateGradOne', + cupy_kernel( + 'kernel_Correlation_updateGradOne', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradOne': gradOne, + 'gradTwo': None + }))(grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cupy.int32(n), intSample, + rbot0.data_ptr(), + rbot1.data_ptr(), + gradOutput.data_ptr(), + gradOne.data_ptr(), None + ]) + # end + # end + + if gradTwo is not None: + for intSample in range(one.shape[0]): + n = one.shape[1] * one.shape[2] * one.shape[3] + cupy_launch( + 'kernel_Correlation_updateGradTwo', + cupy_kernel( + 'kernel_Correlation_updateGradTwo', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradOne': None, + 'gradTwo': gradTwo + }))(grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cupy.int32(n), intSample, + rbot0.data_ptr(), + rbot1.data_ptr(), + gradOutput.data_ptr(), None, + gradTwo.data_ptr() + ]) + # end + # end + + elif one.is_cuda == False: + raise NotImplementedError() + + # end + + return gradOne, gradTwo + + # end + + +# end + + +def FunctionCorrelation(tenOne, tenTwo): + return _FunctionCorrelation.apply(tenOne, tenTwo) + + +# end + + +class ModuleCorrelation(torch.nn.Module): + + def __init__(self): + super().__init__() + + # end + + def forward(self, tenOne, tenTwo): + return _FunctionCorrelation.apply(tenOne, tenTwo) + + # end + + +# end diff --git a/src/pervfiarches/generators/softsplatnet/softsplat.py b/src/pervfiarches/generators/softsplatnet/softsplat.py new file mode 100644 index 00000000..6f613c0a --- /dev/null +++ 
b/src/pervfiarches/generators/softsplatnet/softsplat.py @@ -0,0 +1,608 @@ +#!/usr/bin/env python + +import collections +import os +import re +import typing + +import cupy +import torch + +########################################################## + +objCudacache = {} + + +def cuda_int32(intIn: int): + return cupy.int32(intIn) + + +# end + + +def cuda_float32(fltIn: float): + return cupy.float32(fltIn) + + +# end + + +def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert (False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', + str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', + str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', + str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', + objValue) + + elif type(objValue + ) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue + ) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue + ) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue + ) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue + ) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue + ) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert (False) + + elif True: + print(strVariable, type(objValue)) + assert (False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace( + objMatch.group(), + str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == + False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if 
intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert (intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace( + '{', '(').replace('}', ')').strip() + ')*' + + str(intStrides[intArg] if torch. + is_tensor(intStrides[intArg]) == + False else intStrides[intArg].item()) + + ')') + # end + + strKernel = strKernel.replace( + 'OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert (intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace( + '{', '(').replace('}', ')').strip() + ')*' + + str(intStrides[intArg] if torch. + is_tensor(intStrides[intArg]) == + False else intStrides[intArg].item()) + + ')') + # end + + strKernel = strKernel.replace( + 'VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey + + +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey: str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path() + # end + + return cupy.cuda.compile_with_cache( + objCudacache[strKey]['strKernel'], + tuple([ + '-I ' + os.environ['CUDA_HOME'], + '-I ' + os.environ['CUDA_HOME'] + '/include' + ])).get_function(objCudacache[strKey]['strFunction']) + + +# end + +########################################################## + + +def softsplat(tenIn: torch.Tensor, tenFlow: torch.Tensor, + tenMetric: torch.Tensor, strMode: str) -> torch.Tensor: + assert (strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) + + if strMode == 'sum': assert (tenMetric is None) + if strMode == 'avg': assert (tenMetric is None) + if strMode.split('-')[0] == 'linear': assert (tenMetric is not None) + if strMode.split('-')[0] == 'soft': assert (tenMetric is not None) + + if strMode == 'avg': + tenIn = torch.cat([ + tenIn, + tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]) + ], 1) + + elif strMode.split('-')[0] == 'linear': + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split('-')[0] == 'soft': + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if strMode.split('-')[0] in ['avg', 'linear', 'soft']: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split('-')) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'addeps': + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'zeroeps': + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split('-')[1] 
== 'clipeps': + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut + + +# end + + +class softsplat_func(torch.autograd.Function): + + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros( + [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + + if tenIn.is_cuda == True: + cuda_launch( + cuda_kernel( + 'softsplat_out', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOut': tenOut + }))(grid=tuple( + [int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenOut.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOut.data_ptr() + ], + stream=collections.namedtuple('Stream', 'ptr')( + torch.cuda.current_stream().cuda_stream)) + + elif tenIn.is_cuda != True: + assert 
(False) + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous() + assert (tenOutgrad.is_cuda == True) + + tenIngrad = tenIn.new_zeros([ + tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3] + ]) if self.needs_input_grad[0] == True else None + tenFlowgrad = tenFlow.new_zeros([ + tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], + tenFlow.shape[3] + ]) if self.needs_input_grad[1] == True else None + + if tenIngrad is not None: + cuda_launch( + cuda_kernel( + 'softsplat_ingrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': 
tenFlowgrad + }))(grid=tuple( + [int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenIngrad.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOutgrad.data_ptr(), + tenIngrad.data_ptr(), None + ], + stream=collections.namedtuple('Stream', 'ptr')( + torch.cuda.current_stream().cuda_stream)) + # end + + if tenFlowgrad is not None: + cuda_launch( + cuda_kernel( + 'softsplat_flowgrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if 
((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))(grid=tuple( + [int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[ + cuda_int32(tenFlowgrad.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOutgrad.data_ptr(), None, + tenFlowgrad.data_ptr() + ], + stream=collections.namedtuple('Stream', 'ptr')( + torch.cuda.current_stream().cuda_stream)) + # end + + return tenIngrad, tenFlowgrad + + # end + + +# end diff --git a/src/pervfiarches/generators/thops.py b/src/pervfiarches/generators/thops.py new file mode 100644 index 00000000..3a4b4bde --- /dev/null +++ b/src/pervfiarches/generators/thops.py @@ -0,0 +1,68 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd. +# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode +# +# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE + +import torch + + +def sum(tensor, dim=None, keepdim=False): + if dim is None: + # sum up all dim + return torch.sum(tensor) + else: + if isinstance(dim, int): + dim = [dim] + dim = sorted(dim) + for d in dim: + tensor = tensor.sum(dim=d, keepdim=True) + if not keepdim: + for i, d in enumerate(dim): + tensor.squeeze_(d - i) + return tensor + + +def mean(tensor, dim=None, keepdim=False): + if dim is None: + # mean all dim + return torch.mean(tensor) + else: + if isinstance(dim, int): + dim = [dim] + dim = sorted(dim) + for d in dim: + tensor = tensor.mean(dim=d, keepdim=True) + if not keepdim: + for i, d in enumerate(dim): + tensor.squeeze_(d - i) + return tensor + + +def split_feature(tensor, type="split"): + """ + type = ["split", "cross"] + """ + C = tensor.size(1) + if type == "split": + return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...] + elif type == "cross": + return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 
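+ +# Example (for illustration): for x of shape (N, 4, H, W), split_feature(x, "split") returns the channel halves x[:, :2] and x[:, 2:], +# while split_feature(x, "cross") returns the even/odd channels x[:, 0::2] and x[:, 1::2]. cat_feature below concatenates along +# dim=1, so shapes round-trip but the "cross" ordering stays interleaved rather than being restored: +# a, b = split_feature(torch.randn(1, 4, 8, 8), "cross") +# assert cat_feature(a, b).shape == (1, 4, 8, 8)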
+ + +def cat_feature(tensor_a, tensor_b): + return torch.cat((tensor_a, tensor_b), dim=1) + + +def pixels(tensor): + return int(tensor.size(2) * tensor.size(3)) \ No newline at end of file diff --git a/src/pervfiarches/pipeline.py b/src/pervfiarches/pipeline.py new file mode 100644 index 00000000..34f0992d --- /dev/null +++ b/src/pervfiarches/pipeline.py @@ -0,0 +1,88 @@ +import torch + +from .flow_estimators import build_flow_estimator +from .generators import build_generator_arch + + +def get_z(heat: float, img_size: tuple, batch: int, device: str): + def calc_z_shapes(img_size, n_levels): + h, w = img_size + z_shapes = [] + channel = 3 + + for _ in range(n_levels - 1): + h //= 2 + w //= 2 + channel *= 2 + z_shapes.append((channel, h, w)) + h //= 2 + w //= 2 + z_shapes.append((channel * 4, h, w)) + return z_shapes + + z_list = [] + z_shapes = calc_z_shapes(img_size, 3) + for z in z_shapes: + z_new = torch.randn(batch, *z, device=device) * heat + z_list.append(z_new) + return z_list + + +class Pipeline_infer(torch.nn.Module): + def __init__(self, flownet: str, generator: str, model_file: str, flowCheckpoint: str = None, device = "cuda"): + super().__init__() + if flownet is None: + self.flownet = None + else: + self.flownet, self.compute_flow = build_flow_estimator(flownet, device=device, checkpoint=flowCheckpoint) + self.flownet.to(device).eval() + + self.netG = build_generator_arch(generator) + state_dict = { + k.replace("module.", ""): v for k, v in torch.load(model_file, map_location=device).items() + } + self.netG.load_state_dict(state_dict) + self.netG.to(device).eval() + + def forward(self, img0, img1, heat=0.3, time=0.5, flows=None): + if isinstance(heat, float): + zs = get_z(heat, img0.shape[-2:], img0.shape[0], img0.device) + else: + zs = heat + + fflow, bflow = flows if self.flownet is None else self.compute_flow(img0, img1) + conds = [img0, img1, fflow, bflow] + pred, _ = self.netG(zs=zs, inps=conds, time=time, code="decode") + return torch.clamp(pred, 0.0, 1.0) + + @torch.no_grad() + def inference_rand_noise(self, img0, img1, heat=0.7, time=0.5, flows=None): + zs = get_z(heat, img0.shape[-2:], img0.shape[0], img0.device) + fflow, bflow = flows if flows is not None else self.compute_flow(img0, img1) + + conds = [img0, img1, fflow, bflow] + pred, _ = self.netG(zs=zs, inps=conds, time=time, code="decode") + return torch.clamp(pred, 0.0, 1.0) + + @torch.no_grad() + def inference_best_noise(self, img0, img1, gt, time=0.5, flows=None): + fflow, bflow = flows if flows is not None else self.compute_flow(img0, img1) + conds = [img0, img1, fflow, bflow] + _, pred, _ = self.netG(gt=gt, inps=conds, code="encode_decode", time=time) + return torch.clamp(pred, 0.0, 1.0) + + @torch.no_grad() + def inference_spec_noise(self, img0, img1, zs: list, time=0.5, flows=None): + fflow, bflow = flows if flows is not None else self.compute_flow(img0, img1) + conds = [img0, img1, fflow, bflow] + pred, _ = self.netG(zs=zs, inps=conds, code="decode", time=time) + return torch.clamp(pred, 0.0, 1.0) + + @torch.no_grad() + def generate_masks(self, img0, img1, time=0.5): + zs = get_z(0.4, img0.shape[-2:], img0.shape[0], img0.device) + fflow, bflow = self.compute_flow(img0, img1) + + conds = [img0, img1, fflow, bflow] + pred, smasks = self.netG(zs=zs, inps=conds, time=time, code="decode") + return torch.clamp(pred, 0.0, 1.0), smasks diff --git a/src/unifiedInterpolate.py b/src/unifiedInterpolate.py index ce9800eb..50eb60fe 100644 --- a/src/unifiedInterpolate.py +++ b/src/unifiedInterpolate.py @@ -129,7 +129,6 
@@ def handle_model(self): ) self.firstRun = True - self.stream = torch.cuda.Stream() if self.sceneChange: @@ -540,6 +539,161 @@ def run(self, frame, interpolateFactor, writeBuffer): if self.sceneChange: self.sceneChangeProcess.cacheFrame() +class PerVFIRaftCuda: + def __init__( + self, + interpolateMethod, + half, + width, + height, + interpolateFactor, + sceneChange, + ): + + self.interpolateMethod = interpolateMethod + self.half = half + self.width = width + self.height = height + self.interpolateFactor = interpolateFactor + self.sceneChange = sceneChange + + self.handleModel() + + def handleModel(self): + # Hardcoded to RAFT for now. + self.interpolateMethod = self.interpolateMethod.lower().split("_")[1] + self.filename = modelsMap(self.interpolateMethod) + if not os.path.exists(os.path.join(weightsDir, self.interpolateMethod, self.filename)): + modelPath = downloadModels(model=self.interpolateMethod) + else: + modelPath = os.path.join(weightsDir, self.interpolateMethod, self.filename) + + if not os.path.exists(os.path.join(weightsDir, "raft", "raft-sintel.pth")): + flowPath = downloadModels(model="raft") + else: + flowPath = os.path.join(weightsDir, "raft", "raft-sintel.pth") + + self.isCudaAvailable = torch.cuda.is_available() + self.device = torch.device("cuda" if self.isCudaAvailable else "cpu") + + # Pipeline_infer seems to handle loading the model directly to CUDA; the arch will need changes + # for CPU support and more granularity + from src.pervfiarches.pipeline import Pipeline_infer + match self.interpolateMethod: + case "pervfi_small": + self.model = Pipeline_infer("RAFT", "vb", modelPath, flowPath, self.device) + case "pervfi": + self.model = Pipeline_infer("RAFT", "v00", modelPath, flowPath, self.device) + + + torch.set_grad_enabled(False) + if self.isCudaAvailable: + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + if self.half: + torch.set_default_dtype(torch.float16) + + if self.isCudaAvailable and self.half: + self.model.half() + else: + self.half = False + self.model.float() + + ph = ((self.height - 1) // 8 + 1) * 8 + pw = ((self.width - 1) // 8 + 1) * 8 + self.padding = (0, pw - self.width, 0, ph - self.height) + + + self.I0 = torch.zeros( + 1, + 3, + self.height + self.padding[3], + self.width + self.padding[1], + dtype=torch.float16 if self.half else torch.float32, + device=self.device, + ) + + self.I1 = torch.zeros( + 1, + 3, + self.height + self.padding[3], + self.width + self.padding[1], + dtype=torch.float16 if self.half else torch.float32, + device=self.device, + ) + + self.firstRun = True + self.stream = torch.cuda.Stream() + + if self.sceneChange: + self.sceneChangeProcess = SceneChange(self.half) + + @torch.inference_mode() + def cacheFrame(self): + self.I0.copy_(self.I1, non_blocking=True) + #self.model.cache() + + @torch.inference_mode() + def cacheFrameReset(self): + self.I0.copy_(self.I1, non_blocking=True) + #self.model.cacheReset(self.I0) + + @torch.inference_mode() + def processFrame(self, frame): + return ( + frame.to(self.device, non_blocking=True, dtype=torch.float32 if not self.half else torch.float16) + .permute(2, 0, 1) + .unsqueeze_(0) + .mul_(1 / 255) + .contiguous() + ) + + @torch.inference_mode() + def padFrame(self, frame): + return ( + F.pad(frame, [0, self.padding[1], 0, self.padding[3]]) + if self.padding != (0, 0, 0, 0) + else frame + ) + + @torch.inference_mode() + def run(self, frame, interpolateFactor, writeBuffer): + with torch.cuda.stream(self.stream): + if self.firstRun: + self.I0 =
self.padFrame(self.processFrame(frame)) + self.firstRun = False + return + + self.I1 = self.padFrame(self.processFrame(frame)) + + if self.sceneChange: + if self.sceneChangeProcess.run(self.I0, self.I1): + for _ in range(interpolateFactor - 1): + writeBuffer.write(frame) + self.cacheFrameReset() + self.stream.synchronize() + self.sceneChangeProcess.cacheFrame() + return + + for i in range(interpolateFactor - 1): + timestep = torch.full( + (1, 1, self.height + self.padding[3], self.width + self.padding[1]), + (i + 1) * 1 / interpolateFactor, + dtype=torch.float16 if self.half else torch.float32, + device=self.device, + ) + output = self.model(self.I0, self.I1, timestep, interpolateFactor) + output = output[:, :, : self.height, : self.width] + output = output.mul(255.0).squeeze(0).permute(1, 2, 0) + self.stream.synchronize() + writeBuffer.write(output) + + self.cacheFrame() + if self.sceneChange: + self.sceneChangeProcess.cacheFrame() + + + class SceneChange: def __init__( self,