Commit
fix upscale cuda
NevermindNilas committed Sep 28, 2024
1 parent 3ff90ee commit be02fff
Showing 2 changed files with 41 additions and 19 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.MD
@@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
 - More accurate progress information post-process for comparison purposes.
 - All executables now include precompiled `.pyc` code to improve startup times.
 - Restructured interpolation pipeline for TensorRT, yielding up to `25%` performance gains.
-- Full CUDA workflow now uses `cuda pinned memory workflow` to increase data movement efficiency.
+- Full CUDA workflow now uses `cuda pinned memory workflow` to increase data transfer efficiency.
 - Upgraded `Span` and `Span-TensorRT` models to `Spanimation-V2` - special thanks to @TNTWise.

 #### Improved
@@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
 - System checks are now minimal and log only essential information.
 - System checks are now conditional based on the `--benchmark` flag.
 - Slight adjustments to `getFFMPEG.py`.
+- Improved Upscale-Cuda Code.

 #### Removed
 - Removed `--audio` and related features in favor of `pymediainfo`'s automated detection.
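The `cuda pinned memory workflow` entry above refers to staging frames in page-locked host memory so that host-to-GPU copies can run asynchronously. A minimal sketch of what such a transfer path typically looks like in PyTorch; the buffer shape and names here are illustrative assumptions, not code from this repository:

import torch

# Illustrative pinned-memory upload path; names and shapes are assumptions.
device = torch.device("cuda")
copyStream = torch.cuda.Stream()

# Allocate the host staging buffer once, in page-locked (pinned) memory,
# so the host-to-device copy below can actually run asynchronously.
hostFrame = torch.empty((1080, 1920, 3), dtype=torch.uint8, pin_memory=True)

def uploadFrame(frame: torch.Tensor) -> torch.Tensor:
    hostFrame.copy_(frame)  # stage the incoming frame in pinned memory
    with torch.cuda.stream(copyStream):
        # non_blocking=True only overlaps with compute when the source is pinned
        gpuFrame = hostFrame.to(device, non_blocking=True)
    copyStream.synchronize()  # make sure the copy finished before the frame is used
    return gpuFrame

For pageable (unpinned) sources, non_blocking=True generally degrades to a synchronous copy, which is why pinning matters for transfer throughput.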
57 changes: 39 additions & 18 deletions src/unifiedUpscale.py
@@ -97,32 +97,53 @@ def handleModel(self):
         )

         self.model = self.model.model.to(memory_format=torch.channels_last)
+        self.normStream = torch.cuda.Stream()
+        self.dummyInput = (
+            torch.zeros(
+                (1, 3, self.height, self.width),
+                device=self.device,
+                dtype=torch.float16 if self.half else torch.float32,
+            )
+            .contiguous()
+            .to(memory_format=torch.channels_last)
+        )

-    def __call__(self, frame: torch.tensor) -> torch.tensor:
-        with torch.cuda.stream(self.stream):
-            if self.upscaleSkip is not None:
-                if self.upscaleSkip(frame):
-                    self.skippedCounter += 1
-                    return self.prevFrame
-
-            frame = (
-                frame.to(
-                    self.device,
-                    non_blocking=True,
-                    dtype=torch.float16 if self.half else torch.float32,
-                )
+    @torch.inference_mode()
+    def processFrame(self, frame):
+        with torch.cuda.stream(self.normStream):
+            self.dummyInput.copy_(
+                frame.to(dtype=torch.float16 if self.half else torch.float32)
                 .permute(2, 0, 1)
                 .unsqueeze(0)
                 .to(memory_format=torch.channels_last)
-                .mul(1 / 255)
+                .mul(1 / 255),
+                non_blocking=True,
             )
+        self.normStream.synchronize()

+    @torch.inference_mode()
+    def __call__(self, frame: torch.tensor) -> torch.tensor:
+        if self.upscaleSkip is not None:
+            if self.upscaleSkip(frame):
+                self.skippedCounter += 1
+                return self.prevFrame
+
+        self.processFrame(frame)
+        with torch.cuda.stream(self.stream):
+            output = (
+                self.model(self.dummyInput)
+                .squeeze_(0)
+                .clamp(0, 1)
+                .mul_(255)
+                .permute(1, 2, 0)
+            )
-            output = self.model(frame).squeeze(0).mul(255).permute(1, 2, 0)
+        self.stream.synchronize()

-            if self.upscaleSkip is not None:
+        if self.upscaleSkip is not None:
+            with torch.cuda.stream(self.stream):
                 self.prevFrame.copy_(output, non_blocking=True)
+            self.stream.synchronize()

-            return output
+        return output

     def getSkippedCounter(self):
         return self.skippedCounter
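To make the new structure easier to follow outside the diff: handleModel now preallocates a channels_last input tensor once, processFrame normalizes each incoming frame into that buffer on a dedicated normStream, and __call__ runs the model on the main stream. A condensed, self-contained sketch of that pattern; the stand-in model, the sizes, and the omitted upscaleSkip/prevFrame wiring are assumptions for illustration:

import torch

class StreamedUpscaler:
    """Minimal sketch of the pattern above: one preallocated input buffer,
    one CUDA stream for normalization, one for inference."""

    def __init__(self, model: torch.nn.Module, height: int, width: int, half: bool = True):
        self.device = torch.device("cuda")
        self.dtype = torch.float16 if half else torch.float32
        self.model = model.to(self.device, dtype=self.dtype).eval()
        self.normStream = torch.cuda.Stream()
        self.stream = torch.cuda.Stream()
        # Allocated once; every frame is normalized into this buffer, so no
        # new GPU tensor is created per frame.
        self.dummyInput = torch.zeros(
            (1, 3, height, width), device=self.device, dtype=self.dtype
        ).to(memory_format=torch.channels_last)

    @torch.inference_mode()
    def __call__(self, frame: torch.Tensor) -> torch.Tensor:
        # frame: HWC tensor (on the CPU, ideally pinned, or already on the GPU)
        with torch.cuda.stream(self.normStream):
            self.dummyInput.copy_(
                frame.to(self.device, dtype=self.dtype, non_blocking=True)
                .permute(2, 0, 1)   # HWC -> CHW
                .unsqueeze(0)       # add the batch dimension
                .mul(1 / 255),      # scale [0, 255] to [0, 1]
                non_blocking=True,
            )
        self.normStream.synchronize()  # input must be ready before inference

        with torch.cuda.stream(self.stream):
            output = self.model(self.dummyInput).squeeze_(0).clamp_(0, 1).mul_(255)
        self.stream.synchronize()
        return output.permute(1, 2, 0)  # back to HWC for the encoder

Since each stream context is followed by synchronize(), the two stages run back to back rather than overlapping; the gain comes from reusing one preallocated channels_last buffer instead of allocating and reformatting a fresh GPU tensor for every frame.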
