Skip to content

Commit

Permalink
[CodeCamp2023-652] MMagic 新 config 体验与适配 StyleGAN3 (#2018)
Browse files Browse the repository at this point in the history
* CodeCamp2023_652_2

* CodeCamp2023_652_2

---------

Co-authored-by: LeoXing1996 <[email protected]>
Co-authored-by: rangoliu <[email protected]>
  • Loading branch information
3 people authored Sep 12, 2023
1 parent d140802 commit e1bd41a
Show file tree
Hide file tree
Showing 17 changed files with 1,001 additions and 17 deletions.
15 changes: 10 additions & 5 deletions mmagic/configs/_base_/datasets/cifar10_noaug.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
cifar_pipeline = [dict(type='PackInputs')]
from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler

from mmagic.datasets.cifar10_dataset import CIFAR10
from mmagic.datasets.transforms.formatting import PackInputs

cifar_pipeline = [dict(type=PackInputs)]
cifar_dataset = dict(
type='CIFAR10',
type=CIFAR10,
data_root='./data',
data_prefix='cifar10',
test_mode=False,
Expand All @@ -10,19 +15,19 @@
train_dataloader = dict(
num_workers=2,
dataset=cifar_dataset,
sampler=dict(type='InfiniteSampler', shuffle=True),
sampler=dict(type=InfiniteSampler, shuffle=True),
persistent_workers=True)

val_dataloader = dict(
batch_size=32,
num_workers=2,
dataset=cifar_dataset,
sampler=dict(type='DefaultSampler', shuffle=False),
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)

test_dataloader = dict(
batch_size=32,
num_workers=2,
dataset=cifar_dataset,
sampler=dict(type='DefaultSampler', shuffle=False),
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)
54 changes: 54 additions & 0 deletions mmagic/configs/_base_/datasets/ffhq_flip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler

from mmagic.datasets.basic_image_dataset import BasicImageDataset
from mmagic.datasets.transforms.aug_shape import Flip
from mmagic.datasets.transforms.formatting import PackInputs
from mmagic.datasets.transforms.loading import LoadImageFromFile

dataset_type = BasicImageDataset

train_pipeline = [
dict(type=LoadImageFromFile, key='gt'),
dict(type=Flip, keys=['gt'], direction='horizontal'),
dict(type=PackInputs, keys='gt')
]

val_pipeline = [
dict(type=LoadImageFromFile, key='gt'),
dict(type=PackInputs, keys=['gt'])
]

# `batch_size` and `data_root` need to be set.
train_dataloader = dict(
batch_size=4,
num_workers=8,
persistent_workers=True,
sampler=dict(type=InfiniteSampler, shuffle=True),
dataset=dict(
type=dataset_type,
data_prefix=dict(gt=''),
data_root=None, # set by user
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=4,
num_workers=8,
dataset=dict(
type=dataset_type,
data_prefix=dict(gt=''),
data_root=None, # set by user
pipeline=val_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)

test_dataloader = dict(
batch_size=4,
num_workers=8,
dataset=dict(
type=dataset_type,
data_prefix=dict(gt=''),
data_root=None, # set by user
pipeline=val_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)
30 changes: 19 additions & 11 deletions mmagic/configs/_base_/datasets/imagenet_noaug_128.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
# dataset settings
dataset_type = 'ImageNet'
from mmengine.dataset.sampler import DefaultSampler

from mmagic.datasets.imagenet_dataset import ImageNet
from mmagic.datasets.transforms.aug_shape import Resize
from mmagic.datasets.transforms.crop import CenterCropLongEdge
from mmagic.datasets.transforms.formatting import PackInputs
from mmagic.datasets.transforms.loading import LoadImageFromFile

dataset_type = ImageNet

# different from mmcls, we adopt the setting used in BigGAN.
# Remove `RandomFlip` augmentation and change `RandomCropLongEdge` to
# `CenterCropLongEdge` to eliminate randomness.
# dataset settings
train_pipeline = [
dict(type='LoadImageFromFile', key='img'),
dict(type='CenterCropLongEdge'),
dict(type='Resize', scale=(128, 128), backend='pillow'),
dict(type='PackInputs')
dict(type=LoadImageFromFile, key='img'),
dict(type=CenterCropLongEdge),
dict(type=Resize, scale=(128, 128), backend='pillow'),
dict(type=PackInputs)
]

test_pipeline = [
dict(type='LoadImageFromFile', key='img'),
dict(type='CenterCropLongEdge'),
dict(type='Resize', scale=(128, 128), backend='pillow'),
dict(type='PackInputs')
dict(type=LoadImageFromFile, key='img'),
dict(type=CenterCropLongEdge),
dict(type=Resize, scale=(128, 128), backend='pillow'),
dict(type=PackInputs)
]

train_dataloader = dict(
Expand All @@ -29,7 +37,7 @@
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
sampler=dict(type=DefaultSampler, shuffle=True),
persistent_workers=True)

val_dataloader = dict(
Expand All @@ -41,7 +49,7 @@
ann_file='meta/train.txt',
data_prefix='train',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)

test_dataloader = val_dataloader
51 changes: 51 additions & 0 deletions mmagic/configs/_base_/datasets/unconditional_imgs_flip_512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler

from mmagic.datasets.basic_image_dataset import BasicImageDataset
from mmagic.datasets.transforms.aug_shape import Flip, Resize
from mmagic.datasets.transforms.formatting import PackInputs
from mmagic.datasets.transforms.loading import LoadImageFromFile

dataset_type = BasicImageDataset

# TODO:
train_pipeline = [
dict(type=LoadImageFromFile, key='gt'),
dict(type=Resize, keys='gt', scale=(512, 512)),
dict(type=Flip, keys=['gt'], direction='horizontal'), # TODO:
dict(type=PackInputs)
]

# `batch_size` and `data_root` need to be set.
train_dataloader = dict(
batch_size=None,
num_workers=4,
persistent_workers=True,
sampler=dict(type=InfiniteSampler, shuffle=True),
dataset=dict(
type=dataset_type,
data_prefix=dict(gt=''),
data_root=None, # set by user
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=None,
num_workers=4,
dataset=dict(
type=dataset_type,
data_prefix=dict(gt=''),
data_root=None, # set by user
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)

test_dataloader = dict(
batch_size=None,
num_workers=4,
dataset=dict(
type=dataset_type,
data_prefix=dict(gt=''),
data_root=None, # set by user
pipeline=train_pipeline),
sampler=dict(type=DefaultSampler, shuffle=False),
persistent_workers=True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
dataset_type = 'BasicImageDataset'

train_pipeline = [
dict(type='LoadImageFromFile', key='gt'),
dict(
type='Resize',
keys='gt',
scale=(256, 256),
interpolation='lanczos',
backend='pillow'),
dict(type='Flip', keys=['gt'], direction='horizontal'),
dict(type='PackInputs')
]

# `batch_size` and `data_root` need to be set.
train_dataloader = dict(
batch_size=None,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_prefix=dict(gt=''),
data_root=None, # set by user
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=None,
num_workers=4,
dataset=dict(
type=dataset_type,
data_prefix=dict(gt=''),
data_root=None, # set by user
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True)

test_dataloader = dict(
batch_size=None,
num_workers=4,
dataset=dict(
type=dataset_type,
data_prefix=dict(gt=''),
data_root=None, # set by user
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True)
35 changes: 35 additions & 0 deletions mmagic/configs/_base_/models/base_styleganv3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
# define GAN model
from mmagic.models.base_models.average_model import ExponentialMovingAverage
from mmagic.models.data_preprocessors import DataPreprocessor
from mmagic.models.editors.stylegan2 import StyleGAN2Discriminator
from mmagic.models.editors.stylegan3 import StyleGAN3, StyleGAN3Generator

d_reg_interval = 16
g_reg_interval = 4

g_reg_ratio = g_reg_interval / (g_reg_interval + 1)
d_reg_ratio = d_reg_interval / (d_reg_interval + 1)

model = dict(
type=StyleGAN3,
data_preprocessor=dict(type=DataPreprocessor),
generator=dict(
type=StyleGAN3Generator, # StyleGANv3Generator
noise_size=512,
style_channels=512,
out_size=None, # Need to be set.
img_channels=3,
),
discriminator=dict(
type=StyleGAN2Discriminator,
in_size=None, # Need to be set.
),
ema_config=dict(type=ExponentialMovingAverage),
loss_config=dict(
r1_loss_weight=10. / 2. * d_reg_interval,
r1_interval=d_reg_interval,
norm_mode='HWC',
g_reg_interval=g_reg_interval,
g_reg_weight=2. * g_reg_interval,
pl_batch_shrink=2))
Loading

0 comments on commit e1bd41a

Please sign in to comment.