From 828a8422a83fdf664339b8ebeb3b4a2f377e7307 Mon Sep 17 00:00:00 2001
From: ChairC <974833488@qq.com>
Date: Wed, 18 Sep 2024 08:08:47 +0800
Subject: [PATCH 1/3] Fix: Remove two imshow() duplicate functions.

---
 utils/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/utils.py b/utils/utils.py
index a0cd7ca..bfa6241 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -43,7 +43,7 @@ def plot_one_image_in_images(images, fig_size=(64, 64)):
     plt.figure(figsize=fig_size)
     for i in images.cpu():
         plt.imshow(X=i)
-        plt.imshow()
+        plt.show()
 
 
 def save_images(images, path, **kwargs):

From 95f4bf1a0fe698c2a68ebab845dbc634d3ac2de7 Mon Sep 17 00:00:00 2001
From: ChairC <974833488@qq.com>
Date: Fri, 20 Sep 2024 22:48:38 +0800
Subject: [PATCH 2/3] Update: Remove magic transforms and eval function.

---
 sr/train.py  |  3 ---
 webui/web.py | 12 ------------
 2 files changed, 15 deletions(-)

diff --git a/sr/train.py b/sr/train.py
index a3ff9a6..92fe37f 100644
--- a/sr/train.py
+++ b/sr/train.py
@@ -257,9 +257,6 @@ def train(rank=None, args=None):
 
         # Saving and validating models in the main process
         if save_models:
-            # Val
-            # model.eval()
-
            # Saving model, set the checkpoint name
             save_name = f"ckpt_{str(epoch).zfill(3)}"
             # Init ckpt params
diff --git a/webui/web.py b/webui/web.py
index 75144e6..f87d71d 100644
--- a/webui/web.py
+++ b/webui/web.py
@@ -102,18 +102,6 @@ def train(self, seed, conditional, sample, network, run_name, epochs, batch_size
         results_dir = results_logging[1]
         results_vis_dir = results_logging[2]
         results_tb_dir = results_logging[3]
-        # Dataloader
-        transforms = torchvision.transforms.Compose([
-            # Resize input size
-            # torchvision.transforms.Resize(80), image_size + 1/4 * image_size
-            torchvision.transforms.Resize(size=int(image_size + image_size / 4)),
-            # Random adjustment cropping
-            torchvision.transforms.RandomResizedCrop(size=image_size, scale=RANDOM_RESIZED_CROP_SCALE),
-            # To Tensor Format
-            torchvision.transforms.ToTensor(),
-            # For standardization, the mean and standard deviation
-            torchvision.transforms.Normalize(mean=MEAN, std=STD)
-        ])
         # Load the folder data under the current path,
         # and automatically divide the labels according to the dataset under each file name
         dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,

From 595f9126fb9b2128dcffb7db989e457687bbb5a0 Mon Sep 17 00:00:00 2001
From: ChairC <974833488@qq.com>
Date: Fri, 20 Sep 2024 22:49:19 +0800
Subject: [PATCH 3/3] Update: Replace images.shape[0] to batch_size.

---
 tools/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/train.py b/tools/train.py
index 5dd8d6a..5ce1fc0 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -265,7 +265,7 @@ def train(rank=None, args=None):
             # Enable visualization
             if vis:
                 # images.shape[0] is the number of images in the current batch
-                n = num_vis if num_vis > 0 else images.shape[0]
+                n = num_vis if num_vis > 0 else batch_size
                 sampled_images = diffusion.sample(model=model, n=n)
                 save_images(images=sampled_images,
                             path=os.path.join(results_vis_dir, f"{save_name}.{image_format}"))