Skip to content

Commit

Permalink
push fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
NevermindNilas committed Sep 27, 2024
1 parent 40abc51 commit abee2ba
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 91 deletions.
40 changes: 30 additions & 10 deletions src/rifearches/Rife420_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import math


from torch.nn.functional import interpolate
Expand Down Expand Up @@ -155,9 +156,28 @@ def __init__(
self.tenFlow = tenFlow

self.blocks = [self.block0, self.block1, self.block2, self.block3]

self.paddedHeight = backWarp.shape[2]
self.paddedWidth = backWarp.shape[3]
tmp = max(64, int(64 / 1.0))
self.pw = math.ceil(self.width / tmp) * tmp
self.ph = math.ceil(self.height / tmp) * tmp
self.padding = (0, self.pw - self.width, 0, self.ph - self.height)
hMul = 2 / (self.pw - 1)
vMul = 2 / (self.ph - 1)
self.tenFlow = (
torch.Tensor([hMul, vMul])
.to(device=self.device, dtype=self.dtype)
.reshape(1, 2, 1, 1)
)
self.backWarp = torch.cat(
(
(torch.arange(self.pw) * hMul - 1)
.reshape(1, 1, 1, -1)
.expand(-1, -1, self.ph, -1),
(torch.arange(self.ph) * vMul - 1)
.reshape(1, 1, -1, 1)
.expand(-1, -1, -1, self.pw),
),
dim=1,
).to(device=self.device, dtype=self.dtype)

def forward(self, img0, img1, timestep, f0):
imgs = torch.cat([img0, img1], dim=1)
Expand Down Expand Up @@ -197,7 +217,7 @@ def forward(self, img0, img1, timestep, f0):
temp = torch.cat(
(
wimg, # noqa
wf, # noqa
wf, # noqa
timestep,
mask,
(flows * (1 / scale) if scale != 1 else flows),
Expand All @@ -206,8 +226,8 @@ def forward(self, img0, img1, timestep, f0):
)
temp_ = torch.cat(
(
wimg_rev, # noqa
wf_rev, # noqa
wimg_rev, # noqa
wf_rev, # noqa
1 - timestep,
-mask,
(flows_rev * (1 / scale) if scale != 1 else flows_rev),
Expand All @@ -224,8 +244,8 @@ def forward(self, img0, img1, timestep, f0):
else:
temp = torch.cat(
(
wimg, # noqa
wf, # noqa
wimg, # noqa
wf, # noqa
timestep,
mask,
(flows * (1 / scale) if scale != 1 else flows),
Expand Down Expand Up @@ -271,8 +291,8 @@ def forward(self, img0, img1, timestep, f0):
wimg = torch.reshape(wimg, (1, 6, self.paddedHeight, self.paddedWidth))
wf = torch.reshape(wf, (1, 16, self.paddedHeight, self.paddedWidth))
if self.ensemble:
wimg_rev = torch.cat(torch.split(wimg, [3, 3], dim=1)[::-1], dim=1) # noqa
wf_rev = torch.cat(torch.split(wf, [8, 8], dim=1)[::-1], dim=1) # noqa
wimg_rev = torch.cat(torch.split(wimg, [3, 3], dim=1)[::-1], dim=1) # noqa
wf_rev = torch.cat(torch.split(wf, [8, 8], dim=1)[::-1], dim=1) # noqa
mask = torch.sigmoid(mask)
warped_img0, warped_img1 = torch.split(warped_imgs, [1, 1])
return (
Expand Down
45 changes: 29 additions & 16 deletions src/rifearches/Rife422_lite_v2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn

import math

from torch.nn.functional import interpolate

Expand Down Expand Up @@ -138,8 +138,6 @@ def __init__(
device="cuda",
width=1920,
height=1080,
backWarp=None,
tenFlow=None,
):
super(IFNet, self).__init__()
self.block0 = IFBlock(7 + 8, c=192)
Expand All @@ -153,19 +151,37 @@ def __init__(
self.ensemble = ensemble
self.width = width
self.height = height
self.backWarp = backWarp
self.tenFlow = tenFlow
self.blocks = [self.block0, self.block1, self.block2, self.block3]

self.paddedHeight = backWarp.shape[2]
self.paddedWidth = backWarp.shape[3]
self.blocks = [self.block0, self.block1, self.block2, self.block3]
tmp = max(64, int(64 / 1.0))
self.pw = math.ceil(self.width / tmp) * tmp
self.ph = math.ceil(self.height / tmp) * tmp
self.padding = (0, self.pw - self.width, 0, self.ph - self.height)
hMul = 2 / (self.pw - 1)
vMul = 2 / (self.ph - 1)
self.tenFlow = (
torch.Tensor([hMul, vMul])
.to(device=self.device, dtype=self.dtype)
.reshape(1, 2, 1, 1)
)
self.backWarp = torch.cat(
(
(torch.arange(self.pw) * hMul - 1)
.reshape(1, 1, 1, -1)
.expand(-1, -1, self.ph, -1),
(torch.arange(self.ph) * vMul - 1)
.reshape(1, 1, -1, 1)
.expand(-1, -1, -1, self.pw),
),
dim=1,
).to(device=self.device, dtype=self.dtype)

def forward(self, img0, img1, timestep, f0):
imgs = torch.cat([img0, img1], dim=1)
imgs_2 = torch.reshape(imgs, (2, 3, self.paddedHeight, self.paddedWidth))
imgs_2 = torch.reshape(imgs, (2, 3, self.ph, self.pw))
f1 = self.encode(img1[:, :3])
fs = torch.cat([f0, f1], dim=1)
fs_2 = torch.reshape(fs, (2, 8, self.paddedHeight, self.paddedWidth))
fs_2 = torch.reshape(fs, (2, 4, self.ph, self.pw))
warped_img0 = img0
warped_img1 = img1
flows = None
Expand All @@ -186,13 +202,10 @@ def forward(self, img0, img1, timestep, f0):
1,
)
fds, mask, feat = block(temp, scale=scale)

flows = flows + fds

precomp = (
self.backWarp
+ flows.reshape((2, 2, self.paddedHeight, self.paddedWidth))
* self.tenFlow
self.backWarp + flows.reshape((2, 2, self.ph, self.pw)) * self.tenFlow
).permute(0, 2, 3, 1)
if scale == 1:
warped_imgs = torch.nn.functional.grid_sample(
Expand All @@ -211,8 +224,8 @@ def forward(self, img0, img1, timestep, f0):
align_corners=True,
)
wimg, wf = torch.split(warps, [3, 4], dim=1)
wimg = torch.reshape(wimg, (1, 6, self.paddedHeight, self.paddedWidth))
wf = torch.reshape(wf, (1, 8, self.paddedHeight, self.paddedWidth))
wimg = torch.reshape(wimg, (1, 6, self.ph, self.pw))
wf = torch.reshape(wf, (1, 8, self.ph, self.pw))

mask = torch.sigmoid(mask)
warped_img0, warped_img1 = torch.split(warped_imgs, [1, 1])
Expand Down
2 changes: 0 additions & 2 deletions src/rifearches/Rife422_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ def __init__(
device="cuda",
width=1920,
height=1080,
backWarp=None,
tenFlow=None,
):
super(IFNet, self).__init__()
self.block0 = IFBlock(7 + 16, c=256)
Expand Down
Loading

0 comments on commit abee2ba

Please sign in to comment.