From 9e893791f6505d23571e6a8317ed9e4e1129cbf3 Mon Sep 17 00:00:00 2001 From: the-database <25811902+the-database@users.noreply.github.com> Date: Wed, 31 Jul 2024 15:51:59 -0400 Subject: [PATCH] #33 fix grayscale tiling error --- .../backend/src/nodes/impl/pytorch/auto_split.py | 13 +++++++------ .../src/nodes/impl/upscale/convenient_upscale.py | 2 ++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/MangaJaNaiConverterGui/backend/src/nodes/impl/pytorch/auto_split.py b/MangaJaNaiConverterGui/backend/src/nodes/impl/pytorch/auto_split.py index 139d566..4ed3c1b 100644 --- a/MangaJaNaiConverterGui/backend/src/nodes/impl/pytorch/auto_split.py +++ b/MangaJaNaiConverterGui/backend/src/nodes/impl/pytorch/auto_split.py @@ -10,6 +10,7 @@ from ..upscale.auto_split import Split, Tiler, auto_split from .utils import safe_cuda_cache_empty +from nodes.utils.utils import get_h_w_c def _into_standard_image_form(t: torch.Tensor) -> torch.Tensor: @@ -92,23 +93,23 @@ def upscale(img: np.ndarray, _: object): input_tensor = None try: + _, _, input_channels = get_h_w_c(img) # convert to tensor input_tensor = _into_tensor(img, device, dtype) # expand grayscale tensor to match model input channels - input_ndim = input_tensor.ndim - if input_ndim == 2 and model.input_channels > 1: - input_tensor = input_tensor.unsqueeze(-1).repeat(1, 1, model.input_channels) + # input_ndim = input_tensor.ndim + if input_channels == 1: + input_tensor = input_tensor.repeat(1, 1, model.input_channels) else: input_tensor = _rgb_to_bgr(input_tensor) input_tensor = _into_batched_form(input_tensor) - # inference output_tensor = model(input_tensor) # convert back to numpy output_tensor = _into_standard_image_form(output_tensor) - if input_ndim == 2: - output_tensor = output_tensor[..., 0] + if input_channels == 1: + output_tensor = output_tensor[:, :, 0].unsqueeze(-1) else: output_tensor = _rgb_to_bgr(output_tensor) result = output_tensor.detach().cpu().detach().float().numpy() diff --git a/MangaJaNaiConverterGui/backend/src/nodes/impl/upscale/convenient_upscale.py b/MangaJaNaiConverterGui/backend/src/nodes/impl/upscale/convenient_upscale.py index b7d2b37..70cf8be 100644 --- a/MangaJaNaiConverterGui/backend/src/nodes/impl/upscale/convenient_upscale.py +++ b/MangaJaNaiConverterGui/backend/src/nodes/impl/upscale/convenient_upscale.py @@ -95,6 +95,8 @@ def convenient_upscale( # skip all conversions for grayscale to improve performance by reducing the amount of data that needs to be copied # instead we do the color conversions on the tensors after they're already on the gpu if in_img_c == 1: + if img.ndim == 2: + img = np.expand_dims(img, axis=-1) return upscale(img) return as_target_channels(