From 411ddbb2f173adc603780ac350787d0308e28ddf Mon Sep 17 00:00:00 2001 From: ccareaga Date: Sun, 22 Sep 2024 13:25:58 -0700 Subject: [PATCH] adding initial updated model loading and inference code --- README.md | 41 ++++--- intrinsic/model_util.py | 48 -------- intrinsic/pipeline.py | 254 +++++++++++++++++++++++++++++++++++++--- 3 files changed, 260 insertions(+), 83 deletions(-) delete mode 100644 intrinsic/model_util.py diff --git a/README.md b/README.md index ac0d8bb..0ba9312 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ This repository contains the code for the following papers: **Colorful Diffuse Intrinsic Image Decomposition in the Wild**, [Chris Careaga](https://ccareaga.github.io/) and [Yağız Aksoy](https://yaksoy.github.io), ACM Transactions on Graphics, 2024 \ -(Paper and video coming soon!) +[Paper](https://yaksoy.github.io/papers/TOG24-ColorfulShading.pdf) | [Supplementary](https://yaksoy.github.io/papers/TOG24-ColorfulShading-supp.pdf) **Intrinsic Image Decomposition via Ordinal Shading**, [Chris Careaga](https://ccareaga.github.io/) and [Yağız Aksoy](https://yaksoy.github.io), ACM Transactions on Graphics, 2023 \ [Paper](https://yaksoy.github.io/papers/TOG23-Intrinsic.pdf) | [Video](https://www.youtube.com/watch?v=pWtJd3hqL3c) | [Supplementary](https://yaksoy.github.io/papers/TOG23-Intrinsic-Supp.pdf) | [Data](https://github.com/compphoto/MIDIntrinsics) @@ -11,7 +11,7 @@ This repository contains the code for the following papers: --- -We propose a method for generating high-resolution intrinsic image decompositions, for in-the-wild images. Our method consists of multiple stages. We first estimate a grayscale shading layer using our ordinal shading pipeline. We then estimate low-resolution chromaticity information to account for color illumination effects while maintaining global consistency. Using this initial colorful decomposition, we estimate a high-resolution, sparse albedo layer. We show that our decomposition allows us to train a diffuse shading network using only a single rendered indoor dataset. +We propose a method for generating high-resolution intrinsic image decompositions for in-the-wild images. Our method consists of multiple stages. We first estimate a grayscale shading layer using our ordinal shading pipeline. We then estimate low-resolution chromaticity information to account for colorful illumination effects while maintaining global consistency. Using this initial colorful decomposition, we estimate a high-resolution, sparse albedo layer. We show that our decomposition allows us to train a diffuse shading estimation network using only a single rendered indoor dataset.
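Concretely, the stage outputs relate to the (linear) input image as sketched below. This is a rough, non-authoritative summary keyed to the result dictionary returned by `run_pipeline` in `intrinsic/pipeline.py` later in this patch; the actual code performs additional resizing and clipping:

```python
from chrislib.general import uninvert

# `results` is the dictionary returned by run_pipeline (see the inference example below)
lin_img = results['lin_img']                                    # linear input image used by the pipeline

# stage 1 (gray): albedo from the grayscale stage times the grayscale shading (stored inverted)
gray_recon = results['gry_alb'] * uninvert(results['gry_shd'])[:, :, None]    # ~ lin_img

# stage 3 (albedo): the high-resolution albedo implies a colorful shading layer
clr_shd = lin_img / results['hr_alb']                           # == results['hr_shd'] (up to clipping)

# stage 4 (diffuse): diffuse shading plus a residual for non-diffuse effects
diffuse_recon = results['hr_alb'] * results['dif_shd'] + results['residual']  # == lin_img
```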
![representative](./figures/representative.png) @@ -38,11 +38,10 @@ This will allow you to import the repository as a Python package, and use our pi ## Inference To run our pipeline on your own images you can use the decompose script: ```python -from chrislib.general import view, tile_imgs, view_scale, uninvert +from chrislib.general import uninvert from chrislib.data_util import load_image -from intrinsic.pipeline import run_pipeline -from intrinsic.model_util import load_models +from intrinsic.pipeline import load_models, run_pipeline # load the models from the given paths models = load_models('final_weights.pt') @@ -51,12 +50,7 @@ models = load_models('final_weights.pt') image = load_image('/path/to/input/image') # run the model on the image using R_0 resizing -results = run_pipeline( - models, - image, - resize_conf=0.0, - maintain_size=True -) +results = run_pipeline(models, image) albedo = results['albedo'] inv_shd = results['inv_shading'] @@ -65,16 +59,31 @@ inv_shd = results['inv_shading'] shading = uninvert(inv_shd) ``` -This will run our pipeline and output the linear albedo and shading. You can run this in your browser as well! [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/compphoto/Intrinsic/blob/main/intrinsic_inference.ipynb) +This will run our pipeline and output the linear intrinsic components. You can run this in your browser as well! [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/compphoto/Intrinsic/blob/main/intrinsic_inference.ipynb) ## Citation ``` +@ARTICLE{careagaColorful, + author={Chris Careaga and Ya\u{g}{\i}z Aksoy}, + title={Colorful Diffuse Intrinsic Image Decomposition in the Wild}, + journal={ACM Trans. Graph.}, + year={2024}, + volume = {43}, + number = {6}, + articleno = {178}, + numpages = {12}, +} + @ARTICLE{careagaIntrinsic, - author={Chris Careaga and Ya\u{g}{\i}z Aksoy}, - title={Intrinsic Image Decomposition via Ordinal Shading}, - journal={ACM Trans. Graph.}, - year={2023}, + author={Chris Careaga and Ya\u{g}{\i}z Aksoy}, + title={Intrinsic Image Decomposition via Ordinal Shading}, + journal={ACM Trans. 
Graph.}, + year={2023}, + volume = {43}, + number = {1}, + articleno = {12}, + numpages = {24}, } ``` diff --git a/intrinsic/model_util.py b/intrinsic/model_util.py deleted file mode 100644 index 23945d9..0000000 --- a/intrinsic/model_util.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -from altered_midas.midas_net import MidasNet -from altered_midas.midas_net_custom import MidasNet_small - -def load_models(path, device='cuda'): - """Load the ordinal network and the intrinsic decomposition network - into a dictionary that can be used to run our pipeline - - params: - path (str or list): the path to the combined weights file, or to each individual weights file (ordinal first, then iid) - device (str) optional: the device to run the model on (default "cuda") - - returns: - models (dict): a dict with the following structure: { - "ordinal_model": altered_midas.midas_net.MidasNet, - "real_model": altered_midas.midas_net_custom.MidasNet_small} - """ - models = {} - - if isinstance(path, list): - ord_state_dict = torch.load(path[0]) - iid_state_dict = torch.load(path[1]) - else: - if path == 'paper_weights': - combined_dict = torch.hub.load_state_dict_from_url('https://github.com/compphoto/Intrinsic/releases/download/v1.0/final_weights.pt', map_location=device, progress=True) - elif path == 'rendered_only': - combined_dict = torch.hub.load_state_dict_from_url('https://github.com/compphoto/Intrinsic/releases/download/v1.0/rendered_only_weights.pt', map_location=device, progress=True) - else: - combined_dict = torch.load(path) - - ord_state_dict = combined_dict['ord_state_dict'] - iid_state_dict = combined_dict['iid_state_dict'] - - ord_model = MidasNet() - ord_model.load_state_dict(ord_state_dict) - ord_model.eval() - ord_model = ord_model.to(device) - - iid_model = MidasNet_small(exportable=False, input_channels=5, output_channels=1) - iid_model.load_state_dict(iid_state_dict) - iid_model.eval() - iid_model = iid_model.to(device) - - models['ordinal_model'] = ord_model - models['real_model'] = iid_model - - return models - diff --git a/intrinsic/pipeline.py b/intrinsic/pipeline.py index 8b5aae1..c56bd00 100644 --- a/intrinsic/pipeline.py +++ b/intrinsic/pipeline.py @@ -3,15 +3,114 @@ from skimage.transform import resize from chrislib.resolution_util import optimal_resize -from chrislib.general import round_32, uninvert +from chrislib.general import round_32, uninvert, invert, get_brightness, to2np +from chrislib.color_util import batch_rgb2iuv, batch_iuv2rgb from intrinsic.ordinal_util import base_resize, equalize_predictions +from altered_midas.midas_net import MidasNet +from altered_midas.midas_net_custom import MidasNet_small -def run_pipeline( +STAGE_DICT = { + # 'ordinal': 0, + 'gray': 1, + 'chroma': 2, + 'albedo': 3, + 'diffuse': 4 +} + +V1_DICT = { + 'paper_weights' : 'https://github.com/compphoto/Intrinsic/releases/download/v1.0/final_weights.pt', + 'rendered_only' : 'https://github.com/compphoto/Intrinsic/releases/download/v1.0/rendered_only_weights.pt' +} + +def load_models(path, stage=4, device='cuda'): + """The networks as part of our intrinsic decomposition pipeline. 
Since the pipeline consists of stages, + can load the models up to a specific stage in the pipeline + + params: + paths (str or list): the paths to each of the models in the pipeline, or a name for released weights + stage (int or str) optional: the stage to load the models up to (1-4) (default 4) + if string must be one of the following: "gray", "chroma", "albedo", "diffuse" + device (str) optional: the device to run the model on (default "cuda") + + returns: + models (dict): a dict with the following structure: { + "ordinal_model": altered_midas.midas_net.MidasNet, + "iid_model": altered_midas.midas_net_custom.MidasNet_small, + "col_model": altered_midas.midas_net_custom.MidasNet, + "alb_model": altered_midas.midas_net_custom.MidasNet, + "dif_model": altered_midas.midas_net_custom.MidasNet + } + """ + models = {} + + if isinstance(stage, str): + stage = STAGE_DICT[stage] + + # if the path is a string, we are loading a release of the pipeline + if isinstance(path, str): + if path in ['paper_weights', 'rendered_only']: + # these are V1 releases from the ordinal shading paper, so set the stage to 1 to only run grayscale + combined_dict = torch.hub.load_state_dict_from_url(V1_DICT[path], map_location=device, progress=True) + stage = 1 + + ord_state_dict = combined_dict['ord_state_dict'] + iid_state_dict = combined_dict['iid_state_dict'] + else: + # TODO: otherwise we are loading the colorful version of the pipeline which has different logic + pass + + + elif isinstance(path, list): + + ord_state_dict = torch.load(path[0]) + iid_state_dict = torch.load(path[1]) + + if stage >= 2: col_state_dict = torch.load(path[2]) + if stage >= 3: alb_state_dict = torch.load(path[3]) + if stage >= 4: dif_state_dict = torch.load(path[4]) + + ord_model = MidasNet() + ord_model.load_state_dict(ord_state_dict) + ord_model.eval() + ord_model = ord_model.to(device) + models['ord_model'] = ord_model + + iid_model = MidasNet_small(exportable=False, input_channels=5, output_channels=1) + iid_model.load_state_dict(iid_state_dict) + iid_model.eval() + iid_model = iid_model.to(device) + models['iid_model'] = iid_model + + if stage >= 2: + col_model = MidasNet(activation='sigmoid', input_channels=7, output_channels=2) + col_model.load_state_dict(col_state_dict) + col_model.eval() + col_model = col_model.to(device) + models['col_model'] = col_model + + if stage >= 3: + alb_model = MidasNet(activation='sigmoid', input_channels=9, output_channels=3) + alb_model.load_state_dict(alb_state_dict) + alb_model.eval() + alb_model = alb_model.to(device) + models['alb_model'] = alb_model + + if stage >= 4: + dif_model = MidasNet(activation='sigmoid', input_channels=9, output_channels=3) + dif_model.load_state_dict(dif_state_dict) + dif_model.eval() + dif_model = dif_model.to(device) + models['dif_model'] = dif_model + + + return models + + +def run_gray_pipeline( models, img_arr, - output_ordinal=False, resize_conf=0.0, base_size=384, maintain_size=False, @@ -19,13 +118,11 @@ def run_pipeline( device='cuda', lstsq_p=0.0, inputs='all'): - """Runs the complete pipeline for shading and albedo prediction + """Runs the complete pipeline for grayscale shading and albedo prediction params: - models (dict): models dictionary returned by model_util.load_models() + models (dict): models dictionary returned by load_models() img_arr (np.array): RGB input image as numpy array between 0-1 - output_ordinal (bool) optional: whether or not to output intermediate ordinal estimations - (default False) resize_conf (float) optional: confidence to use 
for resizing (between 0-1) if None maintain original size (default None) base_size (int) optional: size of the base resolution estimation (default 384) @@ -39,7 +136,7 @@ def run_pipeline( always included (default "all") returns: - results (dict): a result dictionary with albedo, shading and potentiall ordinal estimations + results (dict): a result dictionary with albedo, shading and potentially ordinal estimations """ results = {} @@ -80,8 +177,8 @@ def run_pipeline( base_input = torch.from_numpy(base_input).permute(2, 0, 1).to(device).float() full_input = torch.from_numpy(full_input).permute(2, 0, 1).to(device).float() - base_out = models['ordinal_model'](base_input.unsqueeze(0)).squeeze(0) - full_out = models['ordinal_model'](full_input.unsqueeze(0)).squeeze(0) + base_out = models['ord_model'](base_input.unsqueeze(0)).squeeze(0) + full_out = models['ord_model'](full_input.unsqueeze(0)).squeeze(0) # the ordinal estimations come out of the model with a channel dim base_out = base_out.permute(1, 2, 0).cpu().numpy() @@ -102,6 +199,7 @@ def run_pipeline( fll = torch.from_numpy(ord_full).permute(2, 0, 1).to(device) # combine the base and full ordinal estimations w/ the input image + # NOTE: this is just for ablation studies provided in the paper if inputs == 'full': combined = torch.cat((inp, fll), 0).unsqueeze(0) elif inputs == 'base': @@ -111,7 +209,7 @@ def run_pipeline( else: combined = torch.cat((inp, bse, fll), 0).unsqueeze(0) - inv_shd = models['real_model'](combined).squeeze(1) + inv_shd = models['iid_model'](combined).squeeze(1) # the shading comes out in the inverse space so undo it shd = uninvert(inv_shd) @@ -123,19 +221,137 @@ def run_pipeline( alb = alb.permute(1, 2, 0).detach().cpu().numpy() if maintain_size: - if output_ordinal: - ord_base = resize(base_out, (orig_h, orig_w), anti_aliasing=True) - ord_full = resize(full_out, (orig_h, orig_w), anti_aliasing=True) + ord_base = resize(base_out, (orig_h, orig_w), anti_aliasing=True) + ord_full = resize(full_out, (orig_h, orig_w), anti_aliasing=True) inv_shd = resize(inv_shd, (orig_h, orig_w), anti_aliasing=True) alb = resize(alb, (orig_h, orig_w), anti_aliasing=True) - if output_ordinal: - results['ord_full'] = ord_full - results['ord_base'] = ord_base - results['inv_shading'] = inv_shd - results['albedo'] = alb + results['ord_full'] = ord_full + results['ord_base'] = ord_base + + results['gry_shd'] = inv_shd + results['gry_alb'] = alb results['image'] = img_arr + results['lin_img'] = lin_img return results + +def run_pipeline(models, img_arr, stage=4, resize_conf=0.0, base_size=384, linear=False, device='cuda'): + + results = run_gray_pipeline( + models, + img_arr, + resize_conf=resize_conf, + linear=linear, + device=device, + base_size=base_size, + ) + + if stage == 1: + return results + + img = results['lin_img'] + gry_shd = results['gry_shd'][:, :, None] + gry_alb = results['gry_alb'] + + # pytorch versions of the input, gry shd and albedo with channel dim + net_img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device) + net_shd = torch.from_numpy(gry_shd).permute(2, 0, 1).unsqueeze(0).to(device) + net_alb = torch.from_numpy(gry_alb).permute(2, 0, 1).unsqueeze(0).to(device) + + in_img_luv = batch_rgb2iuv(net_img) + in_alb_luv = batch_rgb2iuv(net_alb) + + orig_sz = img.shape[:2] + scale = base_size / max(orig_sz) + base_sz = (round_32(orig_sz[0] * scale), round_32(orig_sz[1] * scale)) + + # we want to resize the inputs to base resolution + in_img_luv = torch.nn.functional.interpolate(in_img_luv, size=base_sz, 
mode='bilinear', align_corners=True, antialias=True) + in_alb_luv = torch.nn.functional.interpolate(in_alb_luv, size=base_sz, mode='bilinear', align_corners=True, antialias=True) + in_gry_shd = torch.nn.functional.interpolate(net_shd, size=base_sz, mode='bilinear', align_corners=True, antialias=True) + + inp = torch.cat([in_img_luv, in_gry_shd, in_alb_luv], 1) + + # this is the shading color components, N x 2 x H x W + with torch.no_grad(): + uv_shd = models['col_model'](inp) + + # resize the low res shd chroma back to original size + uv_shd = torch.nn.functional.interpolate(uv_shd, size=orig_sz, mode='bilinear', align_corners=True) + + # now combine gry shd with chroma in channel dim and convert to rgb + iuv_shd = torch.cat((net_shd, uv_shd), 1) + rough_shd = batch_iuv2rgb(iuv_shd) + rough_alb = net_img / rough_shd + + rough_alb *= 0.75 / torch.quantile(rough_alb, 0.99) + rough_alb = rough_alb.clip(0.001) + rough_shd = net_img / rough_alb + + # convert the low-res chroma decomposition to numpy in case we return early + lr_clr = to2np(batch_iuv2rgb(torch.cat((torch.ones_like(net_shd) * 0.6, uv_shd), 1)).squeeze(0)) + lr_alb = to2np(rough_alb.squeeze(0)) + lr_shd = to2np(rough_shd.squeeze(0)) + wb_img = (lr_alb * get_brightness(lr_shd)).clip(0, 1) + + results['lr_clr'] = lr_clr + results['lr_alb'] = lr_alb + results['lr_shd'] = lr_shd + results['wb_img'] = wb_img + + if stage == 2: + return results + + # albedo estimation net gets img, inverted rgb shd and implied alb + inp = torch.cat([net_img, invert(rough_shd), rough_alb], 1) + + with torch.no_grad(): + pred_alb = models['alb_model'](inp) + + net_clr_shd = net_img / pred_alb.clip(1e-3) + + # convert high-res albedo and shading to numpy + hr_alb = pred_alb.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() + hr_shd = img / hr_alb.clip(1e-3) + hr_clr = batch_rgb2iuv(net_img / pred_alb.clip(1e-4)) + hr_clr[:, 0, :, :] = torch.ones_like(net_shd) * 0.6 + hr_clr = to2np(batch_iuv2rgb(hr_clr).squeeze(0)) + wb_img = (hr_alb * get_brightness(hr_shd)).clip(0, 1) + + results['hr_alb'] = hr_alb + results['hr_shd'] = hr_shd + results['hr_clr'] = hr_clr + results['wb_img'] = wb_img + + if stage == 3: + return results + + inp = torch.cat([net_img, invert(net_clr_shd), pred_alb], 1) + + with torch.no_grad(): + dif_shd = models['dif_model'](inp) + + dif_shd = uninvert(dif_shd) + + dif_shd = dif_shd.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() + + dif_img = (hr_alb * dif_shd) + res = img - dif_img + + neg_res = res.copy() + neg_res[neg_res > 0] = 0 + neg_res = abs(neg_res) + + pos_res = res.copy() + pos_res[pos_res < 0] = 0 + + results['dif_shd'] = dif_shd + results['dif_img'] = dif_img + results['residual'] = res + results['neg_res'] = neg_res + results['pos_res'] = pos_res + + return results \ No newline at end of file
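For reference, here is a rough end-to-end usage sketch of the staged API added in this patch. Single-file loading of the colorful weights is still marked as a TODO above, so this sketch assumes the five per-stage state dicts are available as separate files; the filenames below are placeholders:

```python
from chrislib.data_util import load_image
from chrislib.general import uninvert

from intrinsic.pipeline import load_models, run_pipeline

# hypothetical per-stage weight files: ordinal, grayscale IID, chroma, albedo, diffuse
weights = ['ord.pt', 'iid.pt', 'col.pt', 'alb.pt', 'dif.pt']

# load every network up to the requested stage (1="gray", 2="chroma", 3="albedo", 4="diffuse")
models = load_models(weights, stage='diffuse')

image = load_image('/path/to/input/image')

# run the full pipeline; a smaller `stage` value returns early with the intermediate results
results = run_pipeline(models, image, stage=4)

gry_shd  = uninvert(results['gry_shd'])  # grayscale shading (stored inverted)
albedo   = results['hr_alb']             # high-resolution albedo
dif_shd  = results['dif_shd']            # colorful diffuse shading
residual = results['residual']           # non-diffuse residual (lin_img - albedo * diffuse shading)
```

`load_models` accepts either an integer stage or one of the strings "gray", "chroma", "albedo", "diffuse", and only loads the networks needed for that stage; `run_pipeline` compares the integer `stage` directly and returns early once the requested stage has been computed.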