Skip to content

Commit

Permalink
remove bloat
Browse files Browse the repository at this point in the history
  • Loading branch information
NevermindNilas committed May 2, 2024
1 parent 3aa6172 commit d033039
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 23 deletions.
21 changes: 3 additions & 18 deletions src/unifiedDenoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,6 @@ def handleModel(self):
torch.set_default_dtype(torch.bfloat16)
self.model.bfloat16()

self.padWidth = 0 if self.width % 8 == 0 else 8 - (self.width % 8)
self.padHeight = 0 if self.height % 8 == 0 else 8 - (self.height % 8)

@torch.inference_mode()
def pad_frame(self, frame):
frame = F.pad(frame, [0, self.padWidth, 0, self.padHeight])
return frame

@torch.inference_mode()
def run(self, frame: np.ndarray) -> np.ndarray:
"""
Expand All @@ -123,24 +115,17 @@ def run(self, frame: np.ndarray) -> np.ndarray:
.unsqueeze(0)
.float()
.mul_(1 / 255)
.to(self.device)
)

frame = frame.contiguous(memory_format=torch.channels_last)

if self.isCudaAvailable:
if self.half:
if self.precision == "fp16":
frame = frame.cuda().half()
frame = frame.half()
elif self.precision == "bfloat16":
frame = frame.cuda().bfloat16()
else:
frame = frame.cuda()

if self.padWidth != 0 or self.padHeight != 0:
frame = self.pad_frame(frame)
frame = frame.bfloat16()

frame = self.model(frame)
frame = frame[:, :, : self.height, : self.width]
frame = frame.squeeze(0).permute(1, 2, 0).mul_(255).byte()

return frame.cpu().numpy()
6 changes: 1 addition & 5 deletions src/unifiedUpscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,6 @@ def handleModel(self):
torch.set_default_dtype(torch.float16)
self.model.half()

@torch.inference_mode()
def padFrame(self, frame):
frame = F.pad(frame, [0, self.padWidth, 0, self.padHeight])
return frame

@torch.inference_mode()
def run(self, frame: np.ndarray) -> np.ndarray:
"""
Expand All @@ -116,6 +111,7 @@ def run(self, frame: np.ndarray) -> np.ndarray:
.mul_(1 / 255)
.to(self.device)
)

frame = frame.half() if self.half and self.isCudaAvailable else frame
frame = self.model(frame)

Expand Down

0 comments on commit d033039

Please sign in to comment.