add Scene Change Detection
NevermindNilas committed Jun 4, 2024
1 parent 807c98a commit ccf1109
Showing 13 changed files with 356 additions and 195 deletions.
2 changes: 1 addition & 1 deletion TheAnimeScripter.jsx
@@ -2,7 +2,7 @@ var panelGlobal = this;
var TheAnimeScripter = (function () {

var scriptName = "TheAnimeScripter";
var scriptVersion = "v1.8.1";
var scriptVersion = "v1.8.2";

// Default Values for the settings
var outputFolder = app.settings.haveSetting(scriptName, "outputFolder") ? app.settings.getSetting(scriptName, "outputFolder") : "undefined";
2 changes: 1 addition & 1 deletion gui.py
@@ -37,7 +37,7 @@
format='%(asctime)s %(levelname)s %(name)s %(message)s')
logger=logging.getLogger(__name__)

TITLE = "The Anime Scripter - 1.8.1 (Alpha)"
TITLE = "The Anime Scripter - 1.8.2 (Alpha)"
W, H = 1280, 720

if getattr(sys, "frozen", False):
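
Both gui.py above and main.py below branch on getattr(sys, "frozen", False), which bundlers such as PyInstaller set on the packaged interpreter. A small standalone illustration of that pattern; the frozen branch here uses sys.executable, the conventional choice, and is an assumption since that part of the file sits outside the visible hunks:

import os
import sys

# When bundled, sys.frozen is set and __file__ is not a reliable anchor, so the
# base path is taken from the executable instead (assumed branch).
if getattr(sys, "frozen", False):
    mainPath = os.path.dirname(sys.executable)
else:
    # Matches the fallback visible in the main.py hunk below.
    mainPath = os.path.dirname(os.path.abspath(__file__))
print(mainPath)
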
14 changes: 3 additions & 11 deletions main.py
@@ -39,7 +39,7 @@
else:
mainPath = os.path.dirname(os.path.abspath(__file__))

scriptVersion = "1.8.1"
scriptVersion = "1.8.2"
warnings.filterwarnings("ignore")


@@ -141,8 +141,7 @@ def __init__(self, args):
def processFrame(self, frame):
try:
if self.dedup and self.dedup_method != "ffmpeg":
result = self.dedup_process.run(frame)
if result:
if self.dedup_process.run(frame):
self.dedupCount += 1
return

@@ -153,13 +152,6 @@ def processFrame(self, frame):
frame = self.upscale_process.run(frame)

if self.interpolate:
if self.scenechange:
result = self.scenechange_process.run(frame)
if result:
for _ in (self.interpolate_factor - 1):
self.writeBuffer.write(frame)
return

self.interpolate_process.run(
frame, self.interpolate_factor, self.writeBuffer
)
@@ -500,7 +492,7 @@ def start(self):
videoFiles = [
os.path.join(args.input, file)
for file in os.listdir(args.input)
if file.endswith((".mp4", ".mkv", ".mov", ".avi"))
if file.endswith((".mp4", ".mkv", ".mov", ".avi", ".webm"))
]
toPrint = f"Processing {len(videoFiles)} files"
logging.info(toPrint)
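
Two things change in processFrame above: the dedup check now branches directly on the boolean returned by dedup_process.run(), and the inline scene-change branch is gone because detection now happens inside each interpolator (see the gmfss and initializeModels diffs below). The start() filter also accepts WebM input; a quick standalone check of that filter, since str.endswith taking a tuple is the whole mechanism:

# Standalone check (not project code) of the extension filter in start().
files = ["ep01.mp4", "ep02.webm", "notes.txt", "EP03.MKV"]
videos = [f for f in files if f.endswith((".mp4", ".mkv", ".mov", ".avi", ".webm"))]
print(videos)  # ['ep01.mp4', 'ep02.webm'] -- the match is case-sensitive
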
77 changes: 49 additions & 28 deletions src/gmfss/gmfss_fortuna_union.py
@@ -13,21 +13,33 @@

class GMFSS:
def __init__(
self, interpolation_factor, half, width, height, ensemble=False, nt=1
self,
interpolation_factor,
half,
width,
height,
ensemble=False,
nt=1,
sceneChange=False,
):
self.width = width
self.height = height
self.half = half
self.interpolation_factor = interpolation_factor
self.ensemble = ensemble
self.nt = nt
self.sceneChange = sceneChange

ph = ((self.height - 1) // 32 + 1) * 32
pw = ((self.width - 1) // 32 + 1) * 32
self.padding = (0, pw - self.width, 0, ph - self.height)

if self.width > 1920 or self.height > 1080:
print(yellow("Warning: Output Resolution is higher than 1080p. Expect significant slowdowns or no functionality at all due to VRAM Constraints when using GMFSS, in case of issues consider switching to RIFE."))
print(
yellow(
"Warning: Output Resolution is higher than 1080p. Expect significant slowdowns or no functionality at all due to VRAM Constraints when using GMFSS, in case of issues consider switching to RIFE."
)
)
self.scale = 0.5
else:
self.scale = 1
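
The padding computed a few lines up rounds each dimension to the next multiple of 32, so the network's repeated 2x downsampling always lands on whole pixel counts, and self.padding keeps only the right and bottom margins. A standalone check of the arithmetic for a 1920x1080 input:

# Pad-to-multiple-of-32, as in __init__ above.
width, height = 1920, 1080
ph = ((height - 1) // 32 + 1) * 32   # 1088
pw = ((width - 1) // 32 + 1) * 32    # 1920
padding = (0, pw - width, 0, ph - height)
print(ph, pw, padding)               # 1088 1920 (0, 0, 0, 8)
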
@@ -58,8 +70,8 @@ def handle_model(self):

torch.set_grad_enabled(False)
if self.isCudaAvailable:
#self.stream = [torch.cuda.Stream() for _ in range(self.nt)]
#self.current_stream = 0
# self.stream = [torch.cuda.Stream() for _ in range(self.nt)]
# self.current_stream = 0
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
if self.half:
@@ -80,9 +92,7 @@ def handle_model(self):
3,
self.height + self.padding[3],
self.width + self.padding[1],
dtype=torch.float16
if self.half
else torch.float32,
dtype=torch.float16 if self.half else torch.float32,
device=self.device,
)

@@ -91,16 +101,18 @@ def handle_model(self):
3,
self.height + self.padding[3],
self.width + self.padding[1],
dtype=torch.float16
if self.half
else torch.float32,
dtype=torch.float16 if self.half else torch.float32,
device=self.device,
)

self.stream = torch.cuda.Stream()
self.firstRun = True

@torch.inference_mode()
if self.sceneChange:
from src.unifiedInterpolate import SceneChange
self.sceneChangeProcess = SceneChange(self.half)

@torch.inference_mode()
def make_inference(self, n):
"""
if self.isCudaAvailable:
@@ -114,36 +126,38 @@ def make_inference(self, n):
)
output = self.model(self.I0, self.I1, timestep)

#if self.isCudaAvailable:
#torch.cuda.synchronize(self.stream[self.current_stream])
#self.current_stream = (self.current_stream + 1) % len(self.stream)
# if self.isCudaAvailable:
# torch.cuda.synchronize(self.stream[self.current_stream])
# self.current_stream = (self.current_stream + 1) % len(self.stream)

if self.padding != (0, 0, 0, 0):
output = output[..., : self.height, : self.width]

return output.squeeze(0).permute(1, 2, 0).mul_(255)

@torch.inference_mode()
def cacheFrame(self):
self.I0.copy_(self.I1, non_blocking=True)

@torch.inference_mode()
def processFrame(self, frame):
return (
(
frame.to(self.device)
.permute(2, 0, 1)
.unsqueeze(0)
.float()
frame.to(self.device).permute(2, 0, 1).unsqueeze(0).float()
if not self.half
else frame.to(self.device)
.permute(2, 0, 1)
.unsqueeze(0)
.half()
else frame.to(self.device).permute(2, 0, 1).unsqueeze(0).half()
)
.mul(1 / 255)
.contiguous()
)

@torch.inference_mode()
def padFrame(self, frame):
return F.pad(frame, [0, self.padding[1], 0, self.padding[3]]) if self.padding != (0, 0, 0, 0) else frame
return (
F.pad(frame, [0, self.padding[1], 0, self.padding[3]])
if self.padding != (0, 0, 0, 0)
else frame
)

@torch.inference_mode()
def run(self, frame, interpolateFactor, writeBuffer):
@@ -156,18 +170,25 @@
self.I1 = self.processFrame(frame)
self.I1 = self.padFrame(self.I1)

if self.sceneChange:
if self.sceneChangeProcess.run(self.I0, self.I1):
for _ in range(interpolateFactor - 1):
writeBuffer.write(frame)
self.cacheFrame()
self.stream.synchronize()
return

for i in range(interpolateFactor - 1):
timestep = torch.tensor(
(i + 1) * 1.0 / self.interpolation_factor,
dtype=self.dtype,
device=self.device,
)
output = self.model(
self.I0, self.I1, timestep
)
output = self.model(self.I0, self.I1, timestep)
output = output[:, :, : self.height, : self.width]
output = output.mul(255.0).squeeze(0).permute(1, 2, 0)
self.stream.synchronize()
writeBuffer.write(output)

self.cacheFrame()

self.I0.copy_(self.I1, non_blocking=True)
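
The detection itself is delegated to SceneChange from src/unifiedInterpolate, which is not part of the hunks shown in this view. A minimal sketch of what such a detector could look like, assuming it thresholds the mean absolute difference between downscaled consecutive frames; only the class name and the single constructor argument come from the import above, everything else is an assumption:

import torch
import torch.nn.functional as F

class SceneChange:
    """Hypothetical sketch only -- the real src.unifiedInterpolate.SceneChange
    is not visible in this commit view and may work differently."""

    def __init__(self, half, threshold=0.12):
        self.half = half            # kept for parity with SceneChange(self.half)
        self.threshold = threshold  # assumed cut-off on the mean absolute difference

    @torch.inference_mode()
    def run(self, I0, I1):
        # Compare small float32 copies of the two padded frames; a large mean
        # absolute difference between consecutive frames is treated as a hard cut.
        a = F.interpolate(I0.float(), size=(224, 224), mode="bilinear", align_corners=False)
        b = F.interpolate(I1.float(), size=(224, 224), mode="bilinear", align_corners=False)
        return (a - b).abs().mean().item() > self.threshold

When run() flags a cut, the run() method above writes the incoming frame interpolateFactor - 1 times instead of synthesizing intermediates, so no blend ever spans two shots; on the normal path, with interpolateFactor = 3 for example, the loop evaluates timesteps 1/3 and 2/3.
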
14 changes: 6 additions & 8 deletions src/initializeModels.py
@@ -215,6 +215,7 @@ def initializeModels(self):
self.ensemble,
self.nt,
self.interpolate_factor,
self.scenechange,
)
case "gmfss":
from src.gmfss.gmfss_fortuna_union import GMFSS
@@ -226,6 +227,7 @@
outputHeight,
self.ensemble,
self.nt,
self.scenechange,
)

case (
@@ -236,14 +238,16 @@
| "rife4.16-lite-ncnn"
| "rife4.17-ncnn"
):
from src.rifencnn.rifencnn import rifeNCNN
from src.unifiedInterpolate import rifeNCNN

interpolate_process = rifeNCNN(
self.interpolate_method,
self.ensemble,
self.nt,
outputWidth,
outputHeight,
self.scenechange,
self.half,
)

case (
@@ -263,6 +267,7 @@
self.half,
self.ensemble,
self.nt,
self.scenechange,
)

if self.denoise:
@@ -317,13 +322,6 @@ def initializeModels(self):

# case ffmpeg, ffmpeg works on decode, refer to ffmpegSettings.py ReadBuffer class.

if self.scenechange:
from src.scenechange.scenechange import SceneChange

scenechange_process = SceneChange(
self.half,
)

return (
outputWidth,
outputHeight,
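
The dispatch above uses Python 3.10 match/case with string or-patterns, so several -ncnn model names share one constructor, and rifeNCNN is now imported from src.unifiedInterpolate alongside SceneChange. The standalone scenechange_process block at the bottom is removed because every branch forwards self.scenechange into its interpolator instead. A tiny self-contained illustration of the or-pattern dispatch (method names taken from the diff; the mapping itself is only illustrative):

# Illustration (not project code) of the match/case style used above.
def pick_backend(method: str) -> str:
    match method:
        case "gmfss":
            return "gmfss"
        case "rife4.16-lite-ncnn" | "rife4.17-ncnn":
            return "ncnn"
        case _:
            return "cuda"

print(pick_backend("rife4.16-lite-ncnn"))  # ncnn
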
11 changes: 7 additions & 4 deletions src/rifearches/IFNet_rife415.py
@@ -127,22 +127,25 @@ def __init__(self, ensemble=False, scale=1):
self.ensemble = ensemble
self.counter = 1

def cache(self):
self.f0.copy_(self.f1, non_blocking=True)

def cacheReset(self, frame):
self.f0 = self.encode(frame[:, :3])

def forward(self, img0, img1, timestep, interpolateFactor = 2):
# Overengineered but it seems to work
if interpolateFactor == 2:
if self.f0 is None:
self.f0 = self.encode(img0[:, :3])
else:
self.f0.copy_(self.f1, non_blocking=True)

self.f1 = self.encode(img1[:, :3])
else:
if self.counter == interpolateFactor:
self.counter = 1
if self.f0 is None:
self.f0 = self.encode(img0[:, :3])
else:
self.f0.copy_(self.f1, non_blocking=True)

self.f1 = self.encode(img1[:, :3])
else:
if self.f0 is None or self.f1 is None:
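
The same cache()/cacheReset() pair and the interpolateFactor-aware branch are added to every IFNet variant below: cache() advances f0 to the features already computed for f1, while cacheReset(frame) re-encodes f0 from a supplied frame (for instance after a skipped or deduplicated frame); the call sites live outside these hunks. A self-contained toy, not project code, showing why the reuse matters -- with the cache, each real frame is encoded once per pass instead of interior frames being encoded twice:

import torch

class CachingEncoderToy:
    """Toy stand-in for the f0/f1 feature cache added to the IFNet variants."""

    def __init__(self):
        self.f0 = None
        self.f1 = None
        self.encodes = 0   # counts how often the (expensive) encoder runs

    def encode(self, img):
        self.encodes += 1
        return img * 2.0   # placeholder for the real feature extractor

    def forward(self, img0, img1):
        if self.f0 is None:
            self.f0 = self.encode(img0)                # very first pair only
        else:
            self.f0.copy_(self.f1, non_blocking=True)  # reuse last pair's f1
        self.f1 = self.encode(img1)

toy = CachingEncoderToy()
frames = [torch.full((1, 3, 8, 8), float(i)) for i in range(5)]
for prev, cur in zip(frames, frames[1:]):
    toy.forward(prev, cur)
print(toy.encodes)  # 5 encodes for 4 pairs, instead of 8 without the cache
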
10 changes: 6 additions & 4 deletions src/rifearches/IFNet_rife415lite.py
@@ -127,22 +127,24 @@ def __init__(self, ensemble, scale=1):
self.ensemble = ensemble
self.counter = 1

def cache(self):
self.f0.copy_(self.f1, non_blocking=True)

def cacheReset(self, frame):
self.f0 = self.encode(frame[:, :3])

def forward(self, img0, img1, timestep, interpolateFactor = 2):
# Overengineered but it seems to work
if interpolateFactor == 2:
if self.f0 is None:
self.f0 = self.encode(img0[:, :3])
else:
self.f0.copy_(self.f1, non_blocking=True)

self.f1 = self.encode(img1[:, :3])
else:
if self.counter == interpolateFactor:
self.counter = 1
if self.f0 is None:
self.f0 = self.encode(img0[:, :3])
else:
self.f0.copy_(self.f1, non_blocking=True)
self.f1 = self.encode(img1[:, :3])
else:
if self.f0 is None or self.f1 is None:
10 changes: 6 additions & 4 deletions src/rifearches/IFNet_rife416lite.py
@@ -126,22 +126,24 @@ def __init__(self, ensemble=False, scale=1):
self.ensemble = ensemble
self.counter = 1

def cache(self):
self.f0.copy_(self.f1, non_blocking=True)

def cacheReset(self, frame):
self.f0 = self.encode(frame[:, :3])

def forward(self, img0, img1, timestep, interpolateFactor = 2):
# Overengineered but it seems to work
if interpolateFactor == 2:
if self.f0 is None:
self.f0 = self.encode(img0[:, :3])
else:
self.f0.copy_(self.f1, non_blocking=True)

self.f1 = self.encode(img1[:, :3])
else:
if self.counter == interpolateFactor:
self.counter = 1
if self.f0 is None:
self.f0 = self.encode(img0[:, :3])
else:
self.f0.copy_(self.f1, non_blocking=True)
self.f1 = self.encode(img1[:, :3])
else:
if self.f0 is None or self.f1 is None:
13 changes: 7 additions & 6 deletions src/rifearches/IFNet_rife417.py
@@ -127,22 +127,23 @@ def __init__(self, ensemble=False, scale=1):
self.ensemble = ensemble
self.counter = 1

def cache(self):
self.f0.copy_(self.f1, non_blocking=True)

def cacheReset(self, frame):
self.f0 = self.encode(frame[:, :3])

def forward(self, img0, img1, timestep, interpolateFactor = 2):
# Overengineered but it seems to work
if interpolateFactor == 2:
if self.f0 is None:
self.f0 = self.encode(img0[:, :3])
else:
self.f0.copy_(self.f1, non_blocking=True)

self.f0 = self.encode(img0[:, :3])
self.f1 = self.encode(img1[:, :3])
else:
if self.counter == interpolateFactor:
self.counter = 1
if self.f0 is None:
self.f0 = self.encode(img0[:, :3])
else:
self.f0.copy_(self.f1, non_blocking=True)
self.f1 = self.encode(img1[:, :3])
else:
if self.f0 is None or self.f1 is None: