From ccf11090a24c6c4e7c81ac313df7cab8247f95d3 Mon Sep 17 00:00:00 2001 From: NevermindNilas Date: Tue, 4 Jun 2024 03:34:39 +0300 Subject: [PATCH] add Scene Change Detection --- TheAnimeScripter.jsx | 2 +- gui.py | 2 +- main.py | 14 +- src/gmfss/gmfss_fortuna_union.py | 77 ++++--- src/initializeModels.py | 14 +- src/rifearches/IFNet_rife415.py | 11 +- src/rifearches/IFNet_rife415lite.py | 10 +- src/rifearches/IFNet_rife416lite.py | 10 +- src/rifearches/IFNet_rife417.py | 13 +- src/rifearches/IFNet_rife46.py | 6 + src/scenechange/scenechange.py | 79 ------- src/unifiedInterpolate.py | 308 +++++++++++++++++++++++----- src/unifiedUpscale.py | 5 +- 13 files changed, 356 insertions(+), 195 deletions(-) delete mode 100644 src/scenechange/scenechange.py diff --git a/TheAnimeScripter.jsx b/TheAnimeScripter.jsx index b1635455..5a5d3685 100644 --- a/TheAnimeScripter.jsx +++ b/TheAnimeScripter.jsx @@ -2,7 +2,7 @@ var panelGlobal = this; var TheAnimeScripter = (function () { var scriptName = "TheAnimeScripter"; - var scriptVersion = "v1.8.1"; + var scriptVersion = "v1.8.2"; // Default Values for the settings var outputFolder = app.settings.haveSetting(scriptName, "outputFolder") ? app.settings.getSetting(scriptName, "outputFolder") : "undefined"; diff --git a/gui.py b/gui.py index 7da929e8..b6658074 100644 --- a/gui.py +++ b/gui.py @@ -37,7 +37,7 @@ format='%(asctime)s %(levelname)s %(name)s %(message)s') logger=logging.getLogger(__name__) -TITLE = "The Anime Scripter - 1.8.1 (Alpha)" +TITLE = "The Anime Scripter - 1.8.2 (Alpha)" W, H = 1280, 720 if getattr(sys, "frozen", False): diff --git a/main.py b/main.py index 2657cb86..b2849588 100644 --- a/main.py +++ b/main.py @@ -39,7 +39,7 @@ else: mainPath = os.path.dirname(os.path.abspath(__file__)) -scriptVersion = "1.8.1" +scriptVersion = "1.8.2" warnings.filterwarnings("ignore") @@ -141,8 +141,7 @@ def __init__(self, args): def processFrame(self, frame): try: if self.dedup and self.dedup_method != "ffmpeg": - result = self.dedup_process.run(frame) - if result: + if self.dedup_process.run(frame): self.dedupCount += 1 return @@ -153,13 +152,6 @@ def processFrame(self, frame): frame = self.upscale_process.run(frame) if self.interpolate: - if self.scenechange: - result = self.scenechange_process.run(frame) - if result: - for _ in (self.interpolate_factor - 1): - self.writeBuffer.write(frame) - return - self.interpolate_process.run( frame, self.interpolate_factor, self.writeBuffer ) @@ -500,7 +492,7 @@ def start(self): videoFiles = [ os.path.join(args.input, file) for file in os.listdir(args.input) - if file.endswith((".mp4", ".mkv", ".mov", ".avi")) + if file.endswith((".mp4", ".mkv", ".mov", ".avi", ".webm")) ] toPrint = f"Processing {len(videoFiles)} files" logging.info(toPrint) diff --git a/src/gmfss/gmfss_fortuna_union.py b/src/gmfss/gmfss_fortuna_union.py index 8e9ffe39..f2d50189 100644 --- a/src/gmfss/gmfss_fortuna_union.py +++ b/src/gmfss/gmfss_fortuna_union.py @@ -13,7 +13,14 @@ class GMFSS: def __init__( - self, interpolation_factor, half, width, height, ensemble=False, nt=1 + self, + interpolation_factor, + half, + width, + height, + ensemble=False, + nt=1, + sceneChange=False, ): self.width = width self.height = height @@ -21,13 +28,18 @@ def __init__( self.interpolation_factor = interpolation_factor self.ensemble = ensemble self.nt = nt + self.sceneChange = sceneChange ph = ((self.height - 1) // 32 + 1) * 32 pw = ((self.width - 1) // 32 + 1) * 32 self.padding = (0, pw - self.width, 0, ph - self.height) if self.width > 1920 or self.height > 1080: 
- print(yellow("Warning: Output Resolution is higher than 1080p. Expect significant slowdowns or no functionality at all due to VRAM Constraints when using GMFSS, in case of issues consider switching to RIFE.")) + print( + yellow( + "Warning: Output Resolution is higher than 1080p. Expect significant slowdowns or no functionality at all due to VRAM Constraints when using GMFSS, in case of issues consider switching to RIFE." + ) + ) self.scale = 0.5 else: self.scale = 1 @@ -58,8 +70,8 @@ def handle_model(self): torch.set_grad_enabled(False) if self.isCudaAvailable: - #self.stream = [torch.cuda.Stream() for _ in range(self.nt)] - #self.current_stream = 0 + # self.stream = [torch.cuda.Stream() for _ in range(self.nt)] + # self.current_stream = 0 torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True if self.half: @@ -80,9 +92,7 @@ def handle_model(self): 3, self.height + self.padding[3], self.width + self.padding[1], - dtype=torch.float16 - if self.half - else torch.float32, + dtype=torch.float16 if self.half else torch.float32, device=self.device, ) @@ -91,16 +101,18 @@ def handle_model(self): 3, self.height + self.padding[3], self.width + self.padding[1], - dtype=torch.float16 - if self.half - else torch.float32, + dtype=torch.float16 if self.half else torch.float32, device=self.device, ) self.stream = torch.cuda.Stream() self.firstRun = True - @torch.inference_mode() + if self.sceneChange: + from src.unifiedInterpolate import SceneChange + self.sceneChangeProcess = SceneChange(self.half) + + @torch.inference_mode() def make_inference(self, n): """ if self.isCudaAvailable: @@ -114,28 +126,26 @@ def make_inference(self, n): ) output = self.model(self.I0, self.I1, timestep) - #if self.isCudaAvailable: - #torch.cuda.synchronize(self.stream[self.current_stream]) - #self.current_stream = (self.current_stream + 1) % len(self.stream) + # if self.isCudaAvailable: + # torch.cuda.synchronize(self.stream[self.current_stream]) + # self.current_stream = (self.current_stream + 1) % len(self.stream) if self.padding != (0, 0, 0, 0): output = output[..., : self.height, : self.width] - + return output.squeeze(0).permute(1, 2, 0).mul_(255) + + @torch.inference_mode() + def cacheFrame(self): + self.I0.copy_(self.I1, non_blocking=True) @torch.inference_mode() def processFrame(self, frame): return ( ( - frame.to(self.device) - .permute(2, 0, 1) - .unsqueeze(0) - .float() + frame.to(self.device).permute(2, 0, 1).unsqueeze(0).float() if not self.half - else frame.to(self.device) - .permute(2, 0, 1) - .unsqueeze(0) - .half() + else frame.to(self.device).permute(2, 0, 1).unsqueeze(0).half() ) .mul(1 / 255) .contiguous() @@ -143,7 +153,11 @@ def processFrame(self, frame): @torch.inference_mode() def padFrame(self, frame): - return F.pad(frame, [0, self.padding[1], 0, self.padding[3]]) if self.padding != (0, 0, 0, 0) else frame + return ( + F.pad(frame, [0, self.padding[1], 0, self.padding[3]]) + if self.padding != (0, 0, 0, 0) + else frame + ) @torch.inference_mode() def run(self, frame, interpolateFactor, writeBuffer): @@ -156,18 +170,25 @@ def run(self, frame, interpolateFactor, writeBuffer): self.I1 = self.processFrame(frame) self.I1 = self.padFrame(self.I1) + if self.sceneChange: + if self.sceneChangeProcess.run(self.I0, self.I1): + for _ in range(interpolateFactor - 1): + writeBuffer.write(frame) + self.cacheFrame() + self.stream.synchronize() + return + for i in range(interpolateFactor - 1): timestep = torch.tensor( (i + 1) * 1.0 / self.interpolation_factor, dtype=self.dtype, 
device=self.device, ) - output = self.model( - self.I0, self.I1, timestep - ) + output = self.model(self.I0, self.I1, timestep) output = output[:, :, : self.height, : self.width] output = output.mul(255.0).squeeze(0).permute(1, 2, 0) self.stream.synchronize() writeBuffer.write(output) + + self.cacheFrame() - self.I0.copy_(self.I1, non_blocking=True) diff --git a/src/initializeModels.py b/src/initializeModels.py index f809cd12..0effbb9e 100644 --- a/src/initializeModels.py +++ b/src/initializeModels.py @@ -215,6 +215,7 @@ def initializeModels(self): self.ensemble, self.nt, self.interpolate_factor, + self.scenechange, ) case "gmfss": from src.gmfss.gmfss_fortuna_union import GMFSS @@ -226,6 +227,7 @@ def initializeModels(self): outputHeight, self.ensemble, self.nt, + self.scenechange, ) case ( @@ -236,7 +238,7 @@ def initializeModels(self): | "rife4.16-lite-ncnn" | "rife4.17-ncnn" ): - from src.rifencnn.rifencnn import rifeNCNN + from src.unifiedInterpolate import rifeNCNN interpolate_process = rifeNCNN( self.interpolate_method, @@ -244,6 +246,8 @@ def initializeModels(self): self.nt, outputWidth, outputHeight, + self.scenechange, + self.half, ) case ( @@ -263,6 +267,7 @@ def initializeModels(self): self.half, self.ensemble, self.nt, + self.scenechange, ) if self.denoise: @@ -317,13 +322,6 @@ def initializeModels(self): # case ffmpeg, ffmpeg works on decode, refer to ffmpegSettings.py ReadBuffer class. - if self.scenechange: - from src.scenechange.scenechange import SceneChange - - scenechange_process = SceneChange( - self.half, - ) - return ( outputWidth, outputHeight, diff --git a/src/rifearches/IFNet_rife415.py b/src/rifearches/IFNet_rife415.py index 63f5d1f2..a0ac4ce3 100644 --- a/src/rifearches/IFNet_rife415.py +++ b/src/rifearches/IFNet_rife415.py @@ -127,13 +127,17 @@ def __init__(self, ensemble=False, scale=1): self.ensemble = ensemble self.counter = 1 + def cache(self): + self.f0.copy_(self.f1, non_blocking=True) + + def cacheReset(self, frame): + self.f0 = self.encode(frame[:, :3]) + def forward(self, img0, img1, timestep, interpolateFactor = 2): # Overengineered but it seems to work if interpolateFactor == 2: if self.f0 is None: self.f0 = self.encode(img0[:, :3]) - else: - self.f0.copy_(self.f1, non_blocking=True) self.f1 = self.encode(img1[:, :3]) else: @@ -141,8 +145,7 @@ def forward(self, img0, img1, timestep, interpolateFactor = 2): self.counter = 1 if self.f0 is None: self.f0 = self.encode(img0[:, :3]) - else: - self.f0.copy_(self.f1, non_blocking=True) + self.f1 = self.encode(img1[:, :3]) else: if self.f0 is None or self.f1 is None: diff --git a/src/rifearches/IFNet_rife415lite.py b/src/rifearches/IFNet_rife415lite.py index 59e9012e..40365338 100644 --- a/src/rifearches/IFNet_rife415lite.py +++ b/src/rifearches/IFNet_rife415lite.py @@ -127,13 +127,17 @@ def __init__(self, ensemble, scale=1): self.ensemble = ensemble self.counter = 1 + def cache(self): + self.f0.copy_(self.f1, non_blocking=True) + + def cacheReset(self, frame): + self.f0 = self.encode(frame[:, :3]) + def forward(self, img0, img1, timestep, interpolateFactor = 2): # Overengineered but it seems to work if interpolateFactor == 2: if self.f0 is None: self.f0 = self.encode(img0[:, :3]) - else: - self.f0.copy_(self.f1, non_blocking=True) self.f1 = self.encode(img1[:, :3]) else: @@ -141,8 +145,6 @@ def forward(self, img0, img1, timestep, interpolateFactor = 2): self.counter = 1 if self.f0 is None: self.f0 = self.encode(img0[:, :3]) - else: - self.f0.copy_(self.f1, non_blocking=True) self.f1 = self.encode(img1[:, 
:3]) else: if self.f0 is None or self.f1 is None: diff --git a/src/rifearches/IFNet_rife416lite.py b/src/rifearches/IFNet_rife416lite.py index cb6619cc..abff93b2 100644 --- a/src/rifearches/IFNet_rife416lite.py +++ b/src/rifearches/IFNet_rife416lite.py @@ -126,13 +126,17 @@ def __init__(self, ensemble=False, scale=1): self.ensemble = ensemble self.counter = 1 + def cache(self): + self.f0.copy_(self.f1, non_blocking=True) + + def cacheReset(self, frame): + self.f0 = self.encode(frame[:, :3]) + def forward(self, img0, img1, timestep, interpolateFactor = 2): # Overengineered but it seems to work if interpolateFactor == 2: if self.f0 is None: self.f0 = self.encode(img0[:, :3]) - else: - self.f0.copy_(self.f1, non_blocking=True) self.f1 = self.encode(img1[:, :3]) else: @@ -140,8 +144,6 @@ def forward(self, img0, img1, timestep, interpolateFactor = 2): self.counter = 1 if self.f0 is None: self.f0 = self.encode(img0[:, :3]) - else: - self.f0.copy_(self.f1, non_blocking=True) self.f1 = self.encode(img1[:, :3]) else: if self.f0 is None or self.f1 is None: diff --git a/src/rifearches/IFNet_rife417.py b/src/rifearches/IFNet_rife417.py index 73c63615..62ae7ea9 100644 --- a/src/rifearches/IFNet_rife417.py +++ b/src/rifearches/IFNet_rife417.py @@ -127,22 +127,23 @@ def __init__(self, ensemble=False, scale=1): self.ensemble = ensemble self.counter = 1 + def cache(self): + self.f0.copy_(self.f1, non_blocking=True) + + def cacheReset(self, frame): + self.f0 = self.encode(frame[:, :3]) + def forward(self, img0, img1, timestep, interpolateFactor = 2): # Overengineered but it seems to work if interpolateFactor == 2: if self.f0 is None: - self.f0 = self.encode(img0[:, :3]) - else: - self.f0.copy_(self.f1, non_blocking=True) - + self.f0 = self.encode(img0[:, :3]) self.f1 = self.encode(img1[:, :3]) else: if self.counter == interpolateFactor: self.counter = 1 if self.f0 is None: self.f0 = self.encode(img0[:, :3]) - else: - self.f0.copy_(self.f1, non_blocking=True) self.f1 = self.encode(img1[:, :3]) else: if self.f0 is None or self.f1 is None: diff --git a/src/rifearches/IFNet_rife46.py b/src/rifearches/IFNet_rife46.py index 0e19fc6a..d8eb0125 100644 --- a/src/rifearches/IFNet_rife46.py +++ b/src/rifearches/IFNet_rife46.py @@ -102,6 +102,12 @@ def __init__(self, ensemble=False, scale=1): self.scale_list=[8/scale, 4/scale, 2/scale, 1/scale] self.ensemble = ensemble + def cache(self): + pass + + def cacheReset(self, frame): + pass + def forward( self, image1, diff --git a/src/scenechange/scenechange.py b/src/scenechange/scenechange.py deleted file mode 100644 index 7b8bcc78..00000000 --- a/src/scenechange/scenechange.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import onnxruntime as ort -import os -import logging -import numpy as np - -from torch.nn import functional as F -from src.downloadModels import weightsDir, downloadModels, modelsMap -from src.coloredPrints import yellow - - -class SceneChange: - def __init__( - self, - half, - ): - self.half = half - self.loadModel() - - def loadModel(self): - filename = modelsMap( - "scenechange", - self.half, - ) - - if not os.path.exists(os.path.join(weightsDir, "scenechange", filename)): - modelPath = downloadModels( - "scenechange", - self.half, - ) - - else: - modelPath = os.path.join(weightsDir, "scenechange", filename) - - providers = ort.get_available_providers() - if "DmlExecutionProvider" in providers: - logging.info("DirectML provider available. 
Defaulting to DirectML") - self.model = ort.InferenceSession( - modelPath, providers=["DmlExecutionProvider"] - ) - else: - logging.info( - "DirectML provider not available, falling back to CPU, expect significantly worse performance, ensure that your drivers are up to date and your GPU supports DirectX 12" - ) - self.model = ort.InferenceSession( - modelPath, providers=["CPUExecutionProvider"] - ) - - self.deviceType = "cpu" - self.device = torch.device(self.deviceType) - self.firstRun = True - - @torch.inference_mode() - def processFrame(self, frame): - frame = frame.to(self.device) - frame = frame.unsqueeze(0) - frame = frame.permute(0, 3, 1, 2) - frame = F.interpolate(frame, size=(224, 224), mode="bilinear") - frame = frame - frame = frame.to(self.device).squeeze(0) - frame = frame.half() if self.half else frame.float() - return frame.numpy() - - @torch.inference_mode() - def run(self, frame): - if self.firstRun: - self.I0 = self.processFrame(frame) - self.firstRun = False - - self.I1 = self.processFrame(frame) - - inputs = np.ascontiguousarray(np.concatenate((self.I0, self.I1), 0)) - result = self.model.run(None, {"input": inputs})[0][0][0] * 255 - - print(yellow(f"SceneChange: {result}")) - - self.I0 = self.I1 - - return result > 0.93 # hardcoded for testing purposes diff --git a/src/unifiedInterpolate.py b/src/unifiedInterpolate.py index 0608c160..8a4ae953 100644 --- a/src/unifiedInterpolate.py +++ b/src/unifiedInterpolate.py @@ -2,6 +2,7 @@ import torch import logging import tensorrt as trt +import numpy as np from torch.nn import functional as F from .downloadModels import downloadModels, weightsDir, modelsMap @@ -21,6 +22,7 @@ def __init__( ensemble=False, nt=1, interpolateFactor=2, + sceneChange=False, ): """ Initialize the RIFE model @@ -33,6 +35,8 @@ def __init__( interpolateMethod (str): The method to use for interpolation. ensemble (bool): Whether to use ensemble mode. nt (int): The number of streams to use, not available for now. + interpolateFactor (int): The interpolation factor. + scenechange (bool): Whether to use scene change detection. 
""" self.half = half self.scale = 1.0 @@ -42,12 +46,19 @@ def __init__( self.ensemble = ensemble self.nt = nt self.interpolateFactor = interpolateFactor + self.sceneChange = sceneChange if self.width > 1920 or self.height > 1080: self.scale = 0.5 if self.half: - print(yellow("UHD and fp16 are not compatible with RIFE, defaulting to fp32")) - logging.info("UHD and fp16 for rife are not compatible due to flickering issues, defaulting to fp32") + print( + yellow( + "UHD and fp16 are not compatible with RIFE, defaulting to fp32" + ) + ) + logging.info( + "UHD and fp16 for rife are not compatible due to flickering issues, defaulting to fp32" + ) self.half = False self.handle_model() @@ -104,9 +115,7 @@ def handle_model(self): 3, self.height + self.padding[3], self.width + self.padding[1], - dtype=torch.float16 - if self.half - else torch.float32, + dtype=torch.float16 if self.half else torch.float32, device=self.device, ) @@ -115,46 +124,64 @@ def handle_model(self): 3, self.height + self.padding[3], self.width + self.padding[1], - dtype=torch.float16 - if self.half - else torch.float32, + dtype=torch.float16 if self.half else torch.float32, device=self.device, ) self.firstRun = True self.stream = torch.cuda.Stream() - + + if self.sceneChange: + self.sceneChangeProcess = SceneChange(self.half) + + @torch.inference_mode() + def cacheFrame(self): + self.I0.copy_(self.I1, non_blocking=True) + self.model.cache() + + @torch.inference_mode() + def cacheFrameReset(self): + self.I0.copy_(self.I1, non_blocking=True) + self.model.cacheReset(self.I0) @torch.inference_mode() def processFrame(self, frame): return ( - ( - frame.to(self.device, non_blocking=True, dtype=torch.float32) - .permute(2, 0, 1) - .unsqueeze(0) - if not self.half - else frame.to(self.device, non_blocking=True, dtype=torch.float16) - .permute(2, 0, 1) - .unsqueeze(0) - ) - .mul(1 / 255) - ) + frame.to(self.device, non_blocking=True, dtype=torch.float32) + .permute(2, 0, 1) + .unsqueeze(0) + if not self.half + else frame.to(self.device, non_blocking=True, dtype=torch.float16) + .permute(2, 0, 1) + .unsqueeze(0) + ).mul(1 / 255) @torch.inference_mode() def padFrame(self, frame): - return F.pad(frame, [0, self.padding[1], 0, self.padding[3]]) if self.padding != (0, 0, 0, 0) else frame + return ( + F.pad(frame, [0, self.padding[1], 0, self.padding[3]]) + if self.padding != (0, 0, 0, 0) + else frame + ) @torch.inference_mode() def run(self, frame, interpolateFactor, writeBuffer): with torch.cuda.stream(self.stream): - if self.firstRun is True: - self.I0 = self.processFrame(frame) - self.I0 = self.padFrame(self.I0) + if self.firstRun: + self.I0 = self.padFrame(self.processFrame(frame)) self.firstRun = False return - self.I1 = self.processFrame(frame) - self.I1 = self.padFrame(self.I1) + + self.I1 = self.padFrame(self.processFrame(frame)) + + if self.sceneChange: + if self.sceneChangeProcess.run(self.I0, self.I1): + for _ in range(interpolateFactor - 1): + writeBuffer.write(frame) + self.cacheFrameReset() + self.stream.synchronize() + return for i in range(interpolateFactor - 1): timestep = torch.full( @@ -163,15 +190,13 @@ def run(self, frame, interpolateFactor, writeBuffer): dtype=torch.float16 if self.half else torch.float32, device=self.device, ) - output = self.model( - self.I0, self.I1, timestep, interpolateFactor - ) + output = self.model(self.I0, self.I1, timestep, interpolateFactor) output = output[:, :, : self.height, : self.width] output = output.mul(255.0).squeeze(0).permute(1, 2, 0) self.stream.synchronize() 
writeBuffer.write(output) - self.I0.copy_(self.I1, non_blocking=True) + self.cacheFrame() class RifeTensorRT: @@ -184,6 +209,7 @@ def __init__( half: bool = True, ensemble: bool = False, nt: int = 1, + sceneChange: bool = False, ): """ Interpolates frames using TensorRT @@ -225,11 +251,18 @@ def __init__( self.ensemble = ensemble self.nt = nt self.model = None + self.sceneChange = sceneChange if self.width > 1920 or self.height > 1080: if self.half: - print(yellow("UHD and fp16 are not compatible with RIFE, defaulting to fp32")) - logging.info("UHD and fp16 for rife are not compatible due to flickering issues, defaulting to fp32") + print( + yellow( + "UHD and fp16 are not compatible with RIFE, defaulting to fp32" + ) + ) + logging.info( + "UHD and fp16 for rife are not compatible due to flickering issues, defaulting to fp32" + ) self.half = False self.handleModel() @@ -360,32 +393,44 @@ def handleModel(self): tensor_name = self.engine.get_tensor_name(i) if self.engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT: self.context.set_input_shape(tensor_name, self.dummyInput.shape) - + self.firstRun = True + if self.sceneChange: + self.sceneChangeProcess = SceneChange(self.half) @torch.inference_mode() def processFrame(self, frame): return ( - ( - frame.to(self.device, non_blocking=True, dtype=torch.float32) - .permute(2, 0, 1) - .unsqueeze(0) - if not self.half - else frame.to(self.device, non_blocking=True, dtype=torch.float16) - .permute(2, 0, 1) - .unsqueeze(0) - ) - .mul(1 / 255) - ) + frame.to(self.device, non_blocking=True, dtype=torch.float32) + .permute(2, 0, 1) + .unsqueeze(0) + if not self.half + else frame.to(self.device, non_blocking=True, dtype=torch.float16) + .permute(2, 0, 1) + .unsqueeze(0) + ).mul(1 / 255) + + @torch.inference_mode() + def cacheFrame(self): + self.I0.copy_(self.I1, non_blocking=True) @torch.inference_mode() def run(self, frame, interpolateFactor, writeBuffer): with torch.cuda.stream(self.stream): - if self.firstRun is True: + if self.firstRun: self.I0 = self.processFrame(frame) self.firstRun = False return self.I1 = self.processFrame(frame) + + if self.sceneChange: + if self.sceneChangeProcess.run(self.I0, self.I1): + for _ in range(interpolateFactor - 1): + writeBuffer.write(frame) + self.cacheFrame() + self.stream.synchronize() + return + for i in range(interpolateFactor - 1): timestep = torch.full( (1, 1, self.height, self.width), @@ -398,5 +443,176 @@ def run(self, frame, interpolateFactor, writeBuffer): output = self.dummyOutput.squeeze_(0).permute(1, 2, 0).mul_(255) self.stream.synchronize() writeBuffer.write(output) + + self.cacheFrame() + +class rifeNCNN: + def __init__( + self, interpolateMethod, ensemble=False, nt=1, width=1920, height=1080, sceneChange=False, half=True + ): + self.interpolateMethod = interpolateMethod + self.nt = nt + self.height = height + self.width = width + self.ensemble = ensemble + self.sceneChange = sceneChange + self.half = half + + UHD = True if width >= 3840 or height >= 2160 else False + scale = 2 if UHD else 1 + + from rife_ncnn_vulkan_python import Rife + + match interpolateMethod: + case "rife4.15-ncnn" | "rife-ncnn": + self.interpolateMethod = "rife-v4.15-ncnn" + case "rife4.6-ncnn": + self.interpolateMethod = "rife-v4.6-ncnn" + case "rife4.15-lite-ncnn": + self.interpolateMethod = "rife-v4.15-lite-ncnn" + case "rife4.16-lite-ncnn": + self.interpolateMethod = "rife-v4.16-lite-ncnn" + case "rife4.17-lite-ncnn": + self.interpolateMethod = "rife-v4.17-lite-ncnn" + + self.filename = modelsMap( + 
self.interpolateMethod, + ensemble=self.ensemble, + ) + + if self.filename.endswith("-ncnn.zip"): + self.filename = self.filename[:-9] + elif self.filename.endswith("-ncnn"): + self.filename = self.filename[:-5] + + if not os.path.exists( + os.path.join(weightsDir, self.interpolateMethod, self.filename) + ): + modelPath = downloadModels( + model=self.interpolateMethod, + ensemble=self.ensemble, + ) + else: + modelPath = os.path.join(weightsDir, self.interpolateMethod, self.filename) + + if modelPath.endswith("-ncnn.zip"): + modelPath = modelPath[:-9] + elif modelPath.endswith("-ncnn"): + modelPath = modelPath[:-5] + + self.rife = Rife( + gpuid=0, + model=modelPath, + scale=scale, + tta_mode=False, + tta_temporal_mode=False, + uhd_mode=UHD, + num_threads=self.nt, + ) + + self.frame1 = None + self.shape = (self.height, self.width) + + if self.sceneChange: + self.sceneChangeProcess = SceneChange(self.half) + + def cacheFrame(self): + self.frame1 = self.frame2 + + def run(self, frame, interpolateFactor, writeBuffer): + if self.frame1 is None: + self.frame1 = frame.cpu().numpy().astype("uint8") + return False + + self.frame2 = frame.cpu().numpy().astype("uint8") + + if self.sceneChange: + if self.sceneChangeProcess.runNumpy(self.frame1, self.frame2): + for _ in range(interpolateFactor - 1): + writeBuffer.write(frame) + self.cacheFrame() + return + + for i in range(interpolateFactor - 1): + timestep = (i + 1) * 1 / interpolateFactor + + output = self.rife.process_cv2(self.frame1, self.frame2, timestep=timestep) + + output = torch.from_numpy(output).to(frame.device) + writeBuffer.write(output) + + self.cacheFrame() + +class SceneChange: + def __init__( + self, + half, + ): + self.half = half + + import onnxruntime as ort + import cv2 + + self.ort = ort + self.cv2 = cv2 + + self.loadModel() + + def loadModel(self): + filename = modelsMap( + "scenechange", + self.half, + ) + + if not os.path.exists(os.path.join(weightsDir, "scenechange", filename)): + modelPath = downloadModels( + "scenechange", + self.half, + ) + + else: + modelPath = os.path.join(weightsDir, "scenechange", filename) + + providers = self.ort.get_available_providers() + if "DmlExecutionProvider" in providers: + logging.info( + "DirectML provider available for scenechange detection. 
Defaulting to DirectML" + ) + self.model = self.ort.InferenceSession( + modelPath, providers=["DmlExecutionProvider"] + ) + else: + logging.info( + "DirectML provider not available for scenechange detection, falling back to CPU, expect significantly worse performance, ensure that your drivers are up to date and your GPU supports DirectX 12" + ) + self.model = self.ort.InferenceSession( + modelPath, providers=["CPUExecutionProvider"] + ) + + self.firstRun = True + + @torch.inference_mode() + def run(self, frame0, frame1): + if not self.half: + frame0 = F.interpolate(frame0.float(), size=(224, 224), mode="bilinear").squeeze(0).cpu().numpy() + frame1 = F.interpolate(frame1.float(), size=(224, 224), mode="bilinear").squeeze(0).cpu().numpy() + else: + frame0 = F.interpolate(frame0.float(), size=(224, 224), mode="bilinear").squeeze(0).half().cpu().numpy() + frame1 = F.interpolate(frame1.float(), size=(224, 224), mode="bilinear").squeeze(0).half().cpu().numpy() + + inputs = np.ascontiguousarray(np.concatenate((frame0, frame1), 0)) + result = self.model.run(None, {"input": inputs})[0][0][0] + return result > 0.93 + + def runNumpy(self, frame1, frame2): + frame0 = self.cv2.resize(frame1, (224, 224)).transpose(2, 0, 1) + frame1 = self.cv2.resize(frame2, (224, 224)).transpose(2, 0, 1) + + if self.half: + frame0 = frame0.astype(np.float16) + frame1 = frame1.astype(np.float16) + + inputs = np.ascontiguousarray(np.concatenate((frame0, frame1), 0)) + result = self.model.run(None, {"input": inputs})[0][0][0] - self.I0.copy_(self.I1, non_blocking=True) + return result > 0.93 diff --git a/src/unifiedUpscale.py b/src/unifiedUpscale.py index 6dd5d7e6..bd6037c6 100644 --- a/src/unifiedUpscale.py +++ b/src/unifiedUpscale.py @@ -366,14 +366,13 @@ def handleModel(self): (1, 3, self.height, self.width), device=self.deviceType, dtype=self.torchDType, - ) - self.dummyInput = self.dummyInput.contiguous() + ).contiguous() self.dummyOutput = torch.zeros( (1, 3, self.height * self.upscaleFactor, self.width * self.upscaleFactor), device=self.deviceType, dtype=self.torchDType, - ) + ).contiguous() self.IoBinding.bind_output( name="output",