Skip to content

Commit

Permalink
Merge pull request #91 from chairc/dev
Browse files Browse the repository at this point in the history
Remove two imshow() duplicate functions; Remove magic transforms and eval function; Replace images.shape[0] to batch_size.
  • Loading branch information
chairc authored Sep 20, 2024
2 parents f169fbd + 595f912 commit 4abc63c
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 17 deletions.
3 changes: 0 additions & 3 deletions sr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,6 @@ def train(rank=None, args=None):

# Saving and validating models in the main process
if save_models:
# Val
# model.eval()

# Saving model, set the checkpoint name
save_name = f"ckpt_{str(epoch).zfill(3)}"
# Init ckpt params
Expand Down
2 changes: 1 addition & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def train(rank=None, args=None):
# Enable visualization
if vis:
# images.shape[0] is the number of images in the current batch
n = num_vis if num_vis > 0 else images.shape[0]
n = num_vis if num_vis > 0 else batch_size
sampled_images = diffusion.sample(model=model, n=n)
save_images(images=sampled_images,
path=os.path.join(results_vis_dir, f"{save_name}.{image_format}"))
Expand Down
2 changes: 1 addition & 1 deletion utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def plot_one_image_in_images(images, fig_size=(64, 64)):
plt.figure(figsize=fig_size)
for i in images.cpu():
plt.imshow(X=i)
plt.imshow()
plt.show()


def save_images(images, path, **kwargs):
Expand Down
12 changes: 0 additions & 12 deletions webui/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,6 @@ def train(self, seed, conditional, sample, network, run_name, epochs, batch_size
results_dir = results_logging[1]
results_vis_dir = results_logging[2]
results_tb_dir = results_logging[3]
# Dataloader
transforms = torchvision.transforms.Compose([
# Resize input size
# torchvision.transforms.Resize(80), image_size + 1/4 * image_size
torchvision.transforms.Resize(size=int(image_size + image_size / 4)),
# Random adjustment cropping
torchvision.transforms.RandomResizedCrop(size=image_size, scale=RANDOM_RESIZED_CROP_SCALE),
# To Tensor Format
torchvision.transforms.ToTensor(),
# For standardization, the mean and standard deviation
torchvision.transforms.Normalize(mean=MEAN, std=STD)
])
# Load the folder data under the current path,
# and automatically divide the labels according to the dataset under each file name
dataloader = get_dataset(image_size=image_size, dataset_path=dataset_path, batch_size=batch_size,
Expand Down

0 comments on commit 4abc63c

Please sign in to comment.