
Commit

theoretical support for rife 4.17-lite
NevermindNilas committed Jun 17, 2024
1 parent 681041a commit 80160dc
Showing 6 changed files with 236 additions and 6 deletions.
3 changes: 2 additions & 1 deletion main.py
@@ -261,13 +261,14 @@ def start(self):
"rife4.15",
"rife4.15-lite",
"rife4.16-lite",
"rife4.17",
"rife4.17-lite",
"rife-ncnn",
"rife4.6-ncnn",
"rife4.15-ncnn",
"rife4.15-lite-ncnn",
"rife4.16-lite-ncnn",
"rife4.17-ncnn",
"rife4.17",
"rife4.6-tensorrt",
"rife4.15-tensorrt",
"rife4.15-lite-tensorrt",
3 changes: 3 additions & 0 deletions src/downloadModels.py
@@ -201,6 +201,9 @@ def modelsMap(
        case "rife" | "rife4.17":
            return "rife417.pth"

        case "rife4.17-lite":
            return "rife417_lite.pth"

        case "rife4.15":
            return "rife415.pth"

1 change: 1 addition & 0 deletions src/initializeModels.py
@@ -221,6 +221,7 @@ def initializeModels(self):
            | "rife4.15-lite"
            | "rife4.16-lite"
            | "rife4.17"
            | "rife4.17-lite"
        ):
            from src.unifiedInterpolate import RifeCuda

1 change: 0 additions & 1 deletion src/rifearches/IFNet_rife416lite.py
@@ -17,7 +17,6 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
        nn.LeakyReLU(0.2, True),
    )


def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(
223 changes: 223 additions & 0 deletions src/rifearches/IFNet_rife417lite.py
@@ -0,0 +1,223 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .warplayer import warp


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        ),
        nn.LeakyReLU(0.2, True),
    )


def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        ),
        nn.BatchNorm2d(out_planes),
        nn.LeakyReLU(0.2, True),
    )


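# Context feature encoder: a strided conv, two plain convs, and a transposed conv
# produce 4-channel features at the input resolution; these are the f0/f1
# features that IFNet caches between frames.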
class Head(nn.Module):
    def __init__(self):
        super(Head, self).__init__()
        self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)
        self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)
        self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)
        self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x, feat=False):
        x0 = self.cnn0(x)
        x = self.relu(x0)
        x1 = self.cnn1(x)
        x = self.relu(x1)
        x2 = self.cnn2(x)
        x = self.relu(x2)
        x3 = self.cnn3(x)
        if feat:
            return [x0, x1, x2, x3]
        return x3


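# Residual block: a 3x3 (optionally dilated) conv whose output is scaled by a
# learned per-channel parameter (beta) before being added back to the input.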
class ResConv(nn.Module):
    def __init__(self, c, dilation=1):
        super(ResConv, self).__init__()
        self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        return self.relu(self.conv(x) * self.beta + x)


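# One pyramid level: two stride-2 convs downsample the input by 4, eight ResConv
# blocks refine the features, and the last layer predicts a bidirectional flow
# update (4 channels) plus a blend mask (1 channel), upsampled back to the
# working resolution.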
class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock = nn.Sequential(
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
            ResConv(c),
        )
        self.lastconv = nn.Sequential(
            nn.ConvTranspose2d(c, 4 * 6, 4, 2, 1), nn.PixelShuffle(2)
        )

    def forward(self, x, flow=None, scale=1):
        x = F.interpolate(
            x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
        )
        if flow is not None:
            flow = (
                F.interpolate(
                    flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
                )
                * 1.0
                / scale
            )
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat)
        tmp = self.lastconv(feat)
        tmp = F.interpolate(
            tmp, scale_factor=scale, mode="bilinear", align_corners=False
        )
        flow = tmp[:, :4] * scale
        mask = tmp[:, 4:5]
        return flow, mask


class IFNet(nn.Module):
    def __init__(self, ensemble=False, scale=1):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7 + 8, c=128)
        self.block1 = IFBlock(8 + 4 + 8, c=96)
        self.block2 = IFBlock(8 + 4 + 8, c=64)
        self.block3 = IFBlock(8 + 4 + 8, c=48)
        self.encode = Head()
        self.f0 = None
        self.f1 = None
        self.scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
        self.ensemble = ensemble
        self.counter = 1

    def cache(self):
        self.f0.copy_(self.f1, non_blocking=True)

    def cacheReset(self, frame):
        self.f0 = self.encode(frame[:, :3])

    def forward(self, img0, img1, timestep, interpolateFactor=2):
        # Overengineered but it seems to work
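        # For 2x interpolation, f0 is reused across frame pairs (refreshed via
        # cache()/cacheReset()), so only f1 has to be encoded per call. For
        # higher factors the same frame pair is queried once per timestep, so
        # the features are re-encoded only when the counter wraps around.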
        if interpolateFactor == 2:
            if self.f0 is None:
                self.f0 = self.encode(img0[:, :3])
            self.f1 = self.encode(img1[:, :3])
        else:
            if self.counter == interpolateFactor:
                self.counter = 1
                if self.f0 is None:
                    self.f0 = self.encode(img0[:, :3])
                self.f1 = self.encode(img1[:, :3])
            else:
                if self.f0 is None or self.f1 is None:
                    self.f0 = self.encode(img0[:, :3])
                    self.f1 = self.encode(img1[:, :3])
            self.counter += 1

        merged = []
        warped_img0 = img0
        warped_img1 = img1
        flow = None
        block = [self.block0, self.block1, self.block2, self.block3]
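        # Coarse-to-fine refinement: each pyramid level estimates (or refines)
        # the bidirectional flow and blend mask at a progressively finer scale.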
        for i in range(4):
            if flow is None:
                flow, mask = block[i](
                    torch.cat(
                        (img0[:, :3], img1[:, :3], self.f0, self.f1, timestep), 1
                    ),
                    None,
                    scale=self.scale_list[i],
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat(
                            (img1[:, :3], img0[:, :3], self.f1, self.f0, 1 - timestep),
                            1,
                        ),
                        None,
                        scale=self.scale_list[i],
                    )
                    flow = (flow + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (mask + (-m_)) / 2
            else:
                wf0 = warp(self.f0, flow[:, :2])
                wf1 = warp(self.f1, flow[:, 2:4])
                fd, m0 = block[i](
                    torch.cat(
                        (
                            warped_img0[:, :3],
                            warped_img1[:, :3],
                            wf0,
                            wf1,
                            timestep,
                            mask,
                        ),
                        1,
                    ),
                    flow,
                    scale=self.scale_list[i],
                )
                if self.ensemble:
                    f_, m_ = block[i](
                        torch.cat(
                            (
                                warped_img1[:, :3],
                                warped_img0[:, :3],
                                wf1,
                                wf0,
                                1 - timestep,
                                -mask,
                            ),
                            1,
                        ),
                        torch.cat((flow[:, 2:4], flow[:, :2]), 1),
                        scale=self.scale_list[i],
                    )
                    fd = (fd + torch.cat((f_[:, 2:4], f_[:, :2]), 1)) / 2
                    mask = (m0 + (-m_)) / 2
                else:
                    mask = m0
                flow = flow + fd
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
            merged.append((warped_img0, warped_img1))
        mask = torch.sigmoid(mask)
        merged[3] = warped_img0 * mask + warped_img1 * (1 - mask)
        return merged[3]
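
For orientation, a minimal sketch of driving the new architecture on dummy tensors. The 64x64 shape, the 0.5 timestep, and the random weights are assumptions for illustration only, and it presumes the bundled warplayer runs on the chosen device; in the pipeline the model is constructed, loaded from rife417_lite.pth, and fed padded frames by RifeCuda in src/unifiedInterpolate.py:

    import torch
    from src.rifearches.IFNet_rife417lite import IFNet  # run from the repository root

    model = IFNet(ensemble=False, scale=1).eval()   # random weights; real use loads rife417_lite.pth
    img0 = torch.rand(1, 3, 64, 64)                 # previous frame, RGB in [0, 1]
    img1 = torch.rand(1, 3, 64, 64)                 # next frame
    timestep = torch.full((1, 1, 64, 64), 0.5)      # interpolate halfway between the two frames
    with torch.inference_mode():
        mid = model(img0, img1, timestep)           # (1, 3, 64, 64) interpolated frame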
11 changes: 7 additions & 4 deletions src/unifiedInterpolate.py
@@ -77,6 +77,8 @@ def handle_model(self):
        match self.interpolateMethod:
            case "rife" | "rife4.17":
                from .rifearches.IFNet_rife417 import IFNet
            case "rife4.17-lite":
                from .rifearches.IFNet_rife417lite import IFNet
            case "rife4.15":
                from .rifearches.IFNet_rife415 import IFNet
            case "rife4.15-lite":
@@ -100,6 +102,7 @@ def handle_model(self):
        if self.isCudaAvailable and self.half:
            self.model.half()
        else:
            self.half = False
            self.model.float()

        self.model.load_state_dict(torch.load(modelPath, map_location=self.device))
@@ -129,7 +132,7 @@ def handle_model(self):
        )

        self.firstRun = True
        self.stream = torch.cuda.Stream()
        self.stream = torch.cuda.Stream() if self.isCudaAvailable else None

        if self.sceneChange:
            self.sceneChangeProcess = SceneChange(self.half)
@@ -164,7 +167,7 @@ def padFrame(self, frame):

    @torch.inference_mode()
    def run(self, frame, interpolateFactor, writeBuffer):
        with torch.cuda.stream(self.stream):
        with torch.cuda.stream(self.stream) if self.isCudaAvailable else torch.no_grad():
            if self.firstRun:
                self.I0 = self.padFrame(self.processFrame(frame))
                self.firstRun = False
@@ -177,7 +180,7 @@ def run(self, frame, interpolateFactor, writeBuffer):
                for _ in range(interpolateFactor - 1):
                    writeBuffer.write(frame)
                self.cacheFrameReset()
                self.stream.synchronize()
                self.stream.synchronize() if self.isCudaAvailable else None
                self.sceneChangeProcess.cacheFrame()
                return

@@ -191,7 +194,7 @@ def run(self, frame, interpolateFactor, writeBuffer):
                output = self.model(self.I0, self.I1, timestep, interpolateFactor)
                output = output[:, :, : self.height, : self.width]
                output = output.mul(255.0).squeeze(0).permute(1, 2, 0)
                self.stream.synchronize()
                self.stream.synchronize() if self.isCudaAvailable else None
                writeBuffer.write(output)

            self.cacheFrame()
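
The CUDA-stream guards added above keep the CPU path working when no GPU is present. As a hypothetical refactor of the repeated conditional (streamScope is an illustrative name, not part of the repository), the same pattern could be expressed once:

    import contextlib
    import torch

    def streamScope(stream, cudaAvailable):
        # Scope work to the given CUDA stream when a GPU is available,
        # otherwise fall back to a no-op context on CPU.
        return torch.cuda.stream(stream) if cudaAvailable else contextlib.nullcontext()

    # usage inside run(): with streamScope(self.stream, self.isCudaAvailable): ...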
