Skip to content

Commit

Permalink
simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
NevermindNilas committed Sep 23, 2024
1 parent 03bfdd8 commit eadc1ea
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 78 deletions.
6 changes: 3 additions & 3 deletions PARAMETERS.MD
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ The Anime Scripter is a powerful tool for enhancing and manipulating videos with

#### Multiple Input Files
```sh
--input "G:\TheAnimeScripter\input.mp4;G:\TheAnimeScripter\input - Copy.mp4"
--input "G:\TheAnimeScripter\input.mp4;G:\TheAnimeScripter\test.mp4"

# Or, if the input is a .txt file, this is what the process should look like
--input "G:\TheAnimeScripter\test.txt"

# Contents of TXT:
"G:\TheAnimeScripter\input.mp4" # separate by a new line here
"G:\TheAnimeScripter\input - Copy.mp4"
"G:\TheAnimeScripter\test1.mp4" # separate by a new line here
"G:\TheAnimeScripter\test2.mp4"
```
## Processing Options

Expand Down
2 changes: 1 addition & 1 deletion src/presetLogic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import os
from .coloredPrints import green, bold, green
from .coloredPrints import green, bold


def createPreset(args, mainPath: str):
Expand Down
141 changes: 67 additions & 74 deletions src/unifiedInterpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,12 @@ def handleModel(self):
dtype=self.dType,
)

self.testOutput = torch.zeros(
(3, self.height, self.width),
device=self.device,
dtype=self.dType,
)

self.bindings = [self.dummyInput.data_ptr(), self.dummyOutput.data_ptr()]

for i in range(self.engine.num_io_tensors):
Expand All @@ -399,97 +405,84 @@ def handleModel(self):
self.context.set_input_shape(tensor_name, self.dummyInput.shape)

self.firstRun = True
self.useI0AsSource = True
self.normStream = torch.cuda.Stream()
self.useI0 = True

@torch.inference_mode()
def processFrame(self, frame, toNorm):
match toNorm:
case "I0":
with torch.cuda.stream(self.normStream):
self.I0.copy_(
frame.to(
dtype=self.dType,
non_blocking=True,
)
.permute(2, 0, 1)
.unsqueeze(0)
.mul(1 / 255),
non_blocking=True,
)
self.normStream.synchronize()
case "I1":
with torch.cuda.stream(self.normStream):
self.I1.copy_(
frame.to(
dtype=self.dType,
non_blocking=True,
)
.permute(2, 0, 1)
.unsqueeze(0)
.mul(1 / 255),
non_blocking=True,
)
self.normStream.synchronize()
self.timesteps = torch.empty(
(self.interpolateFactor - 1, 1, 1, self.height, self.width),
dtype=self.dType,
device=self.device,
)

case "cache":
with torch.cuda.stream(self.normStream):
self.I0.copy_(
self.I1,
non_blocking=True,
)
self.normStream.synchronize()
for i in range(self.interpolateFactor - 1):
self.timesteps[i] = torch.full(
(1, 1, self.height, self.width),
(i + 1) * 1.0 / self.interpolateFactor,
dtype=self.dType,
device=self.device,
)

case "dummy":
with torch.cuda.stream(self.normStream):
self.dummyInput.copy_(
torch.cat(
[
self.I0,
self.I1,
frame,
],
dim=1,
),
non_blocking=True,
)
self.normStream.synchronize()
@torch.jit.script
def normalizeFrame(frame: torch.Tensor) -> torch.Tensor:
    """Convert an H x W x C frame into a normalized 1 x C x H x W tensor.

    Moves the channel axis to the front, prepends a batch dimension, and
    scales pixel values from the [0, 255] range into [0, 1].
    """
    chw = frame.permute(2, 0, 1)
    batched = chw.unsqueeze(0)
    return batched.mul(1 / 255)

case "output":
with torch.cuda.stream(self.normStream):
output = self.dummyOutput.squeeze(0).permute(1, 2, 0).mul(255)
self.normStream.synchronize()
return output
@torch.jit.script
def normalizeOutput(output: torch.Tensor) -> torch.Tensor:
    """Convert a 1 x C x H x W model output back into an H x W x C frame.

    Drops the batch dimension, moves channels last, and scales values
    from [0, 1] back up to the [0, 255] range.
    """
    hwc = output.squeeze(0).permute(1, 2, 0)
    return hwc.mul(255)

@torch.inference_mode()
def cacheFrameReset(self, frame):
self.I0.copy_(self.processFrame(frame), non_blocking=True)
self.useI0AsSource = True
def processFrame(self, frame: torch.Tensor):
    """Normalize *frame* and stage it into the active source buffer.

    The frame is cast to ``self.dType``, normalized, and copied into
    ``self.I0`` when ``self.useI0`` is set, otherwise into ``self.I1``.
    The copy runs on ``self.normStream`` and is synchronized before
    returning.
    """
    # Both branches of the original differed only in the destination
    # buffer, so select it once and share the copy logic.
    target = self.I0 if self.useI0 else self.I1
    with torch.cuda.stream(self.normStream):
        target.copy_(
            self.normalizeFrame(frame.to(self.dType)), non_blocking=True
        )
    self.normStream.synchronize()

@torch.inference_mode()
def processDummyFrame(self, frame: torch.Tensor = None, timestep: torch.Tensor = None):
    """Assemble the engine input tensor on the normalization stream.

    Concatenates the two source frames with the per-step ``timestep``
    plane along the channel dimension and copies the result into
    ``self.dummyInput``, synchronizing ``self.normStream`` before
    returning.

    Args:
        frame: Unused. Kept (now optional) for backward compatibility —
            the only visible caller passes ``frame=None`` explicitly.
        timestep: Timestep plane matching the spatial size of the
            source frames; required despite the ``None`` default.
    """
    # NOTE(review): the pair is ordered by self.useI0; presumably this keeps
    # the older frame first for the interpolation model — confirm against
    # the full class definition.
    first, second = (self.I0, self.I1) if self.useI0 else (self.I1, self.I0)
    with torch.cuda.stream(self.normStream):
        self.dummyInput.copy_(
            torch.cat([first, second, timestep], dim=1),
            non_blocking=True,
        )
    self.normStream.synchronize()

@torch.inference_mode()
def processOutput(self):
    """Denormalize the engine output and return it as an H x W x C frame.

    Runs ``normalizeOutput`` over ``self.dummyOutput`` on
    ``self.normStream`` and synchronizes before handing the result back.
    """
    with torch.cuda.stream(self.normStream):
        result = self.normalizeOutput(self.dummyOutput)
    self.normStream.synchronize()
    return result

@torch.inference_mode()
def __call__(self, frame, interpQueue):
    """Consume one frame and emit the interpolated in-between frames.

    The first call only primes the source buffer and produces no output.
    Every subsequent call stages *frame*, then for each precomputed
    timestep builds the engine input, runs the TensorRT context, and
    pushes the denormalized result onto ``interpQueue`` —
    ``self.interpolateFactor - 1`` frames per input frame.

    Args:
        frame: Incoming H x W x C frame tensor.
        interpQueue: Queue receiving the interpolated output frames.
    """
    if self.firstRun:
        self.processFrame(frame)
        self.firstRun = False
        return

    self.processFrame(frame)
    for timestep in self.timesteps:
        self.processDummyFrame(frame=None, timestep=timestep)
        self.context.execute_async_v3(stream_handle=self.stream.cuda_stream)
        self.stream.synchronize()
        interpQueue.put(self.processOutput())

    # NOTE(review): processFrame writes into I0 whenever useI0 is True, so
    # the second call appears to overwrite the frame cached by the first
    # call before it is consumed — confirm against the full class whether
    # this toggle (or processFrame's write-target choice) is meant to
    # alternate the buffers differently.
    self.useI0 = not self.useI0


class RifeNCNN:
Expand Down
File renamed without changes.

0 comments on commit eadc1ea

Please sign in to comment.