Skip to content

Commit

Permalink
add pervfi interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
NevermindNilas committed Jun 14, 2024
1 parent f100458 commit 5ff2bf5
Show file tree
Hide file tree
Showing 50 changed files with 8,776 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ Both internal and user generated benchmarks can be found [here](BENCHMARKS.MD).
| [cszn](https://github.com/cszn/DPIR) | DPIR |
| [TNTWise](https://github.com/TNTwise) | For Rife ONNX and NCNN models |
| [WolframRhodium](https://github.com/WolframRhodium) | For Rife V2 models |
| [mulns](https://github.com/mulns/PerVFI) | For PerVFI |

## 🌟 Star History

Expand Down
2 changes: 2 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ def start(self):
"rife4.17-tensorrt",
"rife-tensorrt",
"gmfss",
"raft_pervfi_lite",
"raft_pervfi",
],
default="rife",
)
Expand Down
1 change: 1 addition & 0 deletions requirements-linux.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ GPUtil
spandrel
yt-dlp
requests
accelerate
tqdm
tensorrt
scikit-image
Expand Down
1 change: 1 addition & 0 deletions requirements-windows.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ kornia
scipy
wmi
spandrel
accelerate
yt-dlp
requests
tqdm
Expand Down
12 changes: 12 additions & 0 deletions src/downloadModels.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def modelsList() -> list[str]:
"small-directml",
"base-directml",
"large-directml",
"pervfi_lite",
"pervfi",
]


Expand Down Expand Up @@ -327,6 +329,16 @@ def modelsMap(
return "maxxvitv2_rmlp_base_rw_224.sw_in12k_b80_224px_20k_coloraug0.4_6ch_clamp_softmax_fp16_op17_onnxslim.onnx"
else:
return "maxxvitv2_rmlp_base_rw_224.sw_in12k_b80_224px_20k_coloraug0.4_6ch_clamp_softmax_op17_onnxslim.onnx"

case "pervfi_lite":
return "pervfi_lite.pth"

case "pervfi":
return "pervfi.pth"

case "raft":
return "raft-sintel.pth"

case _:
raise ValueError(f"Model {model} not found.")

Expand Down
16 changes: 16 additions & 0 deletions src/initializeModels.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,22 @@ def initializeModels(self):
self.scenechange,
)

case (
"raft_pervfi_lite"
| "raft_pervfi"
):

from src.unifiedInterpolate import PerVFIRaftCuda

interpolate_process = PerVFIRaftCuda(
self.interpolate_method,
self.half,
outputWidth,
outputHeight,
self.interpolate_factor,
self.scenechange,
)

if self.denoise:
match self.denoise_method:
case "scunet" | "dpir" | "nafnet":
Expand Down
166 changes: 166 additions & 0 deletions src/pervfiarches/flow_estimators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import torch


class InputPadder:
"""Pads images such that dimensions are divisible by factor"""

def __init__(self, size, divide=8, mode="center"):
self.ht, self.wd = size[-2:]
pad_ht = (((self.ht // divide) + 1) * divide - self.ht) % divide
pad_wd = (((self.wd // divide) + 1) * divide - self.wd) % divide
if mode == "center":
self._pad = [
pad_wd // 2,
pad_wd - pad_wd // 2,
pad_ht // 2,
pad_ht - pad_ht // 2,
]
else:
self._pad = [0, pad_wd, 0, pad_ht]

def _pad_(self, x):
return torch.nn.functional.pad(x, self._pad, mode="constant")

def pad(self, *inputs):
return [self._pad_(x) for x in inputs]

def _unpad_(self, x):
return x[
...,
self._pad[2] : self.ht + self._pad[2],
self._pad[0] : self.wd + self._pad[0],
]

def unpad(self, *inputs):
return [self._unpad_(x) for x in inputs]


def build_flow_estimator(name, device="cuda", checkpoint=None):
if name.lower() == "raft":
import argparse

from .raft.raft import RAFT

args = argparse.Namespace(
mixed_precision=True, alternate_corr=False, small=False
)
model = RAFT(args)

if checkpoint is None:
ckpt = "checkpoints/RAFT/raft-sintel.pth"
else:
ckpt = checkpoint # Path to the checkpoint .pth, modified from original arch for better comp

model.load_state_dict(
{k.replace("module.", ""): v for k, v in torch.load(ckpt, map_location=device).items()}
)
model.to(device).eval()

@torch.no_grad()
def infer(I1, I2):
I1 = I1.to(device) * 255.0
I2 = I2.to(device) * 255.0
padder = InputPadder(I1.shape, 8)
I1, I2 = padder.pad(I1, I2)
fflow = model(I1, I2, bidirection=False, iters=12)
bflow = model(I2, I1, bidirection=False, iters=12)
return padder.unpad(fflow, bflow)

if name.lower() == "raft_small":
import argparse

from .raft.raft import RAFT

args = argparse.Namespace(
mixed_precision=True, alternate_corr=False, small=True
)
model = RAFT(args)
ckpt = "checkpoints/RAFT/raft-small.pth"
model.load_state_dict(
{k.replace("module.", ""): v for k, v in torch.load(ckpt).items()}
)
model.to(device).eval()

@torch.no_grad()
def infer(I1, I2):
I1 = I1.to(device) * 255.0
I2 = I2.to(device) * 255.0
padder = InputPadder(I1.shape, 8)
I1, I2 = padder.pad(I1, I2)
fflow = model(I1, I2, bidirection=False, iters=12)
bflow = model(I2, I1, bidirection=False, iters=12)
return padder.unpad(fflow, bflow)

if name.lower() == "gma":
import argparse

from .gma.network import RAFTGMA

args = argparse.Namespace(
mixed_precision=True,
num_heads=1,
position_only=False,
position_and_content=False,
)
model = RAFTGMA(args)
ckpt = "checkpoints/GMA/gma-sintel.pth"
model.load_state_dict(
{k.replace("module.", ""): v for k, v in torch.load(ckpt).items()}
)
model.to(device).eval()

@torch.no_grad()
def infer(I1, I2):
I1 = I1.to(device) * 255.0
I2 = I2.to(device) * 255.0
padder = InputPadder(I1.shape, 8)
I1, I2 = padder.pad(I1, I2)
_, fflow = model(I1, I2, test_mode=True, iters=20)
_, bflow = model(I2, I1, test_mode=True, iters=20)
return padder.unpad(fflow, bflow)

if name.lower() == "gmflow":
from .gmflow.gmflow import GMFlow

model = GMFlow(
feature_channels=128,
num_scales=1,
upsample_factor=8,
num_head=1,
attention_type="swin",
ffn_dim_expansion=4,
num_transformer_layers=6,
)
ckpt = "checkpoints/GMFlow/gmflow_sintel-0c07dcb3.pth"
model.load_state_dict(torch.load(ckpt)["model"])
model.to(device).eval()

@torch.no_grad()
def infer(I1, I2):
I1 = I1.to(device) * 255.0
I2 = I2.to(device) * 255.0
padder = InputPadder(I1.shape, 16)
I1, I2 = padder.pad(I1, I2)
results_dict = model(
I1,
I2,
attn_splits_list=[2],
corr_radius_list=[-1],
prop_radius_list=[-1],
pred_bidir_flow=False,
)
fflow = results_dict["flow_preds"][-1]
results_dict = model(
I2,
I1,
attn_splits_list=[2],
corr_radius_list=[-1],
prop_radius_list=[-1],
pred_bidir_flow=False,
)
bflow = results_dict["flow_preds"][-1]

fflow, bflow = padder.unpad(fflow, bflow)
return fflow, bflow

return model, infer
106 changes: 106 additions & 0 deletions src/pervfiarches/flow_estimators/gma/corr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils.utils import bilinear_sampler, coords_grid

# from compute_sparse_correlation import compute_sparse_corr, compute_sparse_corr_torch, compute_sparse_corr_mink

try:
import alt_cuda_corr
except:
# alt_cuda_corr is not compiled
pass


class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []

# all pairs correlation
corr = CorrBlock.corr(fmap1, fmap2)

batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

self.corr_pyramid.append(corr)
for i in range(self.num_levels - 1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)

def __call__(self, coords):
r = self.radius
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape

out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl

corr = bilinear_sampler(corr, coords_lvl)
corr = corr.view(batch, h1, w1, -1)
out_pyramid.append(corr)

out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()

@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht * wd)
fmap2 = fmap2.view(batch, dim, ht * wd)

corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())


class CorrBlockSingleScale(nn.Module):
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
super().__init__()
self.radius = radius

# all pairs correlation
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
self.corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

def __call__(self, coords):
r = self.radius
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape

corr = self.corr
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2)
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl

corr = bilinear_sampler(corr, coords_lvl)
out = corr.view(batch, h1, w1, -1)
out = out.permute(0, 3, 1, 2).contiguous().float()
return out

@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht * wd)
fmap2 = fmap2.view(batch, dim, ht * wd)

corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
Loading

0 comments on commit 5ff2bf5

Please sign in to comment.