-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CodeCamp2023-652] MMagic 新 config 体验与适配 StyleGAN3 (#2018)
* CodeCamp2023_652_2 * CodeCamp2023_652_2 --------- Co-authored-by: LeoXing1996 <[email protected]> Co-authored-by: rangoliu <[email protected]>
- Loading branch information
1 parent
d140802
commit e1bd41a
Showing
17 changed files
with
1,001 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler | ||
|
||
from mmagic.datasets.basic_image_dataset import BasicImageDataset | ||
from mmagic.datasets.transforms.aug_shape import Flip | ||
from mmagic.datasets.transforms.formatting import PackInputs | ||
from mmagic.datasets.transforms.loading import LoadImageFromFile | ||
|
||
dataset_type = BasicImageDataset | ||
|
||
train_pipeline = [ | ||
dict(type=LoadImageFromFile, key='gt'), | ||
dict(type=Flip, keys=['gt'], direction='horizontal'), | ||
dict(type=PackInputs, keys='gt') | ||
] | ||
|
||
val_pipeline = [ | ||
dict(type=LoadImageFromFile, key='gt'), | ||
dict(type=PackInputs, keys=['gt']) | ||
] | ||
|
||
# `batch_size` and `data_root` need to be set. | ||
train_dataloader = dict( | ||
batch_size=4, | ||
num_workers=8, | ||
persistent_workers=True, | ||
sampler=dict(type=InfiniteSampler, shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_prefix=dict(gt=''), | ||
data_root=None, # set by user | ||
pipeline=train_pipeline)) | ||
|
||
val_dataloader = dict( | ||
batch_size=4, | ||
num_workers=8, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_prefix=dict(gt=''), | ||
data_root=None, # set by user | ||
pipeline=val_pipeline), | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
persistent_workers=True) | ||
|
||
test_dataloader = dict( | ||
batch_size=4, | ||
num_workers=8, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_prefix=dict(gt=''), | ||
data_root=None, # set by user | ||
pipeline=val_pipeline), | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
persistent_workers=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
mmagic/configs/_base_/datasets/unconditional_imgs_flip_512x512.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.dataset.sampler import DefaultSampler, InfiniteSampler | ||
|
||
from mmagic.datasets.basic_image_dataset import BasicImageDataset | ||
from mmagic.datasets.transforms.aug_shape import Flip, Resize | ||
from mmagic.datasets.transforms.formatting import PackInputs | ||
from mmagic.datasets.transforms.loading import LoadImageFromFile | ||
|
||
dataset_type = BasicImageDataset | ||
|
||
# TODO: | ||
train_pipeline = [ | ||
dict(type=LoadImageFromFile, key='gt'), | ||
dict(type=Resize, keys='gt', scale=(512, 512)), | ||
dict(type=Flip, keys=['gt'], direction='horizontal'), # TODO: | ||
dict(type=PackInputs) | ||
] | ||
|
||
# `batch_size` and `data_root` need to be set. | ||
train_dataloader = dict( | ||
batch_size=None, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type=InfiniteSampler, shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_prefix=dict(gt=''), | ||
data_root=None, # set by user | ||
pipeline=train_pipeline)) | ||
|
||
val_dataloader = dict( | ||
batch_size=None, | ||
num_workers=4, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_prefix=dict(gt=''), | ||
data_root=None, # set by user | ||
pipeline=train_pipeline), | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
persistent_workers=True) | ||
|
||
test_dataloader = dict( | ||
batch_size=None, | ||
num_workers=4, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_prefix=dict(gt=''), | ||
data_root=None, # set by user | ||
pipeline=train_pipeline), | ||
sampler=dict(type=DefaultSampler, shuffle=False), | ||
persistent_workers=True) |
48 changes: 48 additions & 0 deletions
48
mmagic/configs/_base_/datasets/unconditional_imgs_flip_lanczos_resize_256x256.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
dataset_type = 'BasicImageDataset' | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile', key='gt'), | ||
dict( | ||
type='Resize', | ||
keys='gt', | ||
scale=(256, 256), | ||
interpolation='lanczos', | ||
backend='pillow'), | ||
dict(type='Flip', keys=['gt'], direction='horizontal'), | ||
dict(type='PackInputs') | ||
] | ||
|
||
# `batch_size` and `data_root` need to be set. | ||
train_dataloader = dict( | ||
batch_size=None, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type='InfiniteSampler', shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_prefix=dict(gt=''), | ||
data_root=None, # set by user | ||
pipeline=train_pipeline)) | ||
|
||
val_dataloader = dict( | ||
batch_size=None, | ||
num_workers=4, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_prefix=dict(gt=''), | ||
data_root=None, # set by user | ||
pipeline=train_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
persistent_workers=True) | ||
|
||
test_dataloader = dict( | ||
batch_size=None, | ||
num_workers=4, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_prefix=dict(gt=''), | ||
data_root=None, # set by user | ||
pipeline=train_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
persistent_workers=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
# define GAN model | ||
from mmagic.models.base_models.average_model import ExponentialMovingAverage | ||
from mmagic.models.data_preprocessors import DataPreprocessor | ||
from mmagic.models.editors.stylegan2 import StyleGAN2Discriminator | ||
from mmagic.models.editors.stylegan3 import StyleGAN3, StyleGAN3Generator | ||
|
||
d_reg_interval = 16 | ||
g_reg_interval = 4 | ||
|
||
g_reg_ratio = g_reg_interval / (g_reg_interval + 1) | ||
d_reg_ratio = d_reg_interval / (d_reg_interval + 1) | ||
|
||
model = dict( | ||
type=StyleGAN3, | ||
data_preprocessor=dict(type=DataPreprocessor), | ||
generator=dict( | ||
type=StyleGAN3Generator, # StyleGANv3Generator | ||
noise_size=512, | ||
style_channels=512, | ||
out_size=None, # Need to be set. | ||
img_channels=3, | ||
), | ||
discriminator=dict( | ||
type=StyleGAN2Discriminator, | ||
in_size=None, # Need to be set. | ||
), | ||
ema_config=dict(type=ExponentialMovingAverage), | ||
loss_config=dict( | ||
r1_loss_weight=10. / 2. * d_reg_interval, | ||
r1_interval=d_reg_interval, | ||
norm_mode='HWC', | ||
g_reg_interval=g_reg_interval, | ||
g_reg_weight=2. * g_reg_interval, | ||
pl_batch_shrink=2)) |
Oops, something went wrong.