Skip to content

Commit

Permalink
Merge pull request #34 from the-database/dev
Browse files Browse the repository at this point in the history
#33 fix grayscale tiling error
  • Loading branch information
the-database authored Jul 31, 2024
2 parents 5b839cd + 9e89379 commit bff5015
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..upscale.auto_split import Split, Tiler, auto_split
from .utils import safe_cuda_cache_empty
from nodes.utils.utils import get_h_w_c


def _into_standard_image_form(t: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -92,23 +93,23 @@ def upscale(img: np.ndarray, _: object):

input_tensor = None
try:
_, _, input_channels = get_h_w_c(img)
# convert to tensor
input_tensor = _into_tensor(img, device, dtype)
# expand grayscale tensor to match model input channels
input_ndim = input_tensor.ndim
if input_ndim == 2 and model.input_channels > 1:
input_tensor = input_tensor.unsqueeze(-1).repeat(1, 1, model.input_channels)
# input_ndim = input_tensor.ndim
if input_channels == 1:
input_tensor = input_tensor.repeat(1, 1, model.input_channels)
else:
input_tensor = _rgb_to_bgr(input_tensor)
input_tensor = _into_batched_form(input_tensor)

# inference
output_tensor = model(input_tensor)

# convert back to numpy
output_tensor = _into_standard_image_form(output_tensor)
if input_ndim == 2:
output_tensor = output_tensor[..., 0]
if input_channels == 1:
output_tensor = output_tensor[:, :, 0].unsqueeze(-1)
else:
output_tensor = _rgb_to_bgr(output_tensor)
result = output_tensor.detach().cpu().detach().float().numpy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def convenient_upscale(
# skip all conversions for grayscale to improve performance by reducing the amount of data that needs to be copied
# instead we do the color conversions on the tensors after they're already on the gpu
if in_img_c == 1:
if img.ndim == 2:
img = np.expand_dims(img, axis=-1)
return upscale(img)

return as_target_channels(
Expand Down

0 comments on commit bff5015

Please sign in to comment.