Skip to content

Commit

Permalink
Merge pull request #93 from chairc/dev
Browse files Browse the repository at this point in the history
Add banner and version information.
  • Loading branch information
chairc authored Oct 20, 2024
2 parents 4abc63c + 61a970a commit e6d3e26
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 3 deletions.
3 changes: 2 additions & 1 deletion config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
from .choices import bool_choices, sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \
image_format_choices, noise_schedule_choices
from .setting import MASTER_ADDR, MASTER_PORT, EMA_BETA, RANDOM_RESIZED_CROP_SCALE, MEAN, STD
from .version import __version__, get_versions, get_latest_version, get_old_versions, check_version_is_latest
from .version import __version__, get_versions, get_latest_version, get_old_versions, check_version_is_latest, \
get_version_banner
8 changes: 8 additions & 0 deletions config/banner.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
_____ _
| __ \ (_)
| |__) | _ _ __ _ __ _ _ __ __ _
| _ / | | | '_ \| '_ \| | '_ \ / _` |
| | \ \ |_| | | | | | | | | | | | (_| | _ _ _
|_| \_\__,_|_| |_|_| |_|_|_| |_|\__, | (_) (_) (_)
__/ |
|___/
14 changes: 14 additions & 0 deletions config/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,21 @@ def check_version_is_latest(current_version):
return False


def get_version_banner():
"""
Get version banner.
"""
with open(file="../config/banner.txt", mode="r", encoding="utf-8") as banner_file:
contents = banner_file.read()
print(contents)
print(f"===============IDDM version: {get_latest_version()}===============\n"
"Project Author : chairc\n"
"Project GitHub : https://github.com/chairc/Integrated-Design-Diffusion-Model")
banner_file.close()


if __name__ == "__main__":
get_versions()
get_latest_version()
get_old_versions()
get_version_banner()
3 changes: 3 additions & 0 deletions sr/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from PIL import Image

sys.path.append(os.path.dirname(sys.path[0]))
from config.version import get_version_banner
from sr.interface import inference, load_sr_model
from utils.initializer import device_initializer
from utils.utils import plot_images, save_images, check_and_create_dir
Expand Down Expand Up @@ -81,4 +82,6 @@ def lr2hr(args):
parser.add_argument("--result_path", type=str, default="/your/path/Diffusion-Model/result")

args = parser.parse_args()
# Get version banner
get_version_banner()
lr2hr(args)
4 changes: 3 additions & 1 deletion sr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
sys.path.append(os.path.dirname(sys.path[0]))
from config.choices import loss_func_choices, sr_network_choices, optim_choices
from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
from config.version import get_version_banner
from model.modules.ema import EMA
from utils.initializer import device_initializer, seed_initializer, sr_network_initializer, optimizer_initializer, \
lr_initializer, amp_initializer, loss_initializer
Expand Down Expand Up @@ -379,5 +380,6 @@ def main(args):
parser.add_argument("--world_size", type=int, default=2)

args = parser.parse_args()

# Get version banner
get_version_banner()
main(args)
3 changes: 3 additions & 0 deletions tools/FID_calculator_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pytorch_fid.inception import InceptionV3

sys.path.append(os.path.dirname(sys.path[0]))
from config.version import get_version_banner
from utils.initializer import device_initializer

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -68,4 +69,6 @@ def main(args):
# Set the use GPU in normal training (required)
parser.add_argument("--use_gpu", type=int, default=0)
args = parser.parse_args()
# Get version banner
get_version_banner()
main(args)
3 changes: 3 additions & 0 deletions tools/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

sys.path.append(os.path.dirname(sys.path[0]))
from config.choices import sample_choices, network_choices, act_choices, image_format_choices, parse_image_size_type
from config.version import get_version_banner
from utils.check import check_image_size
from utils.initializer import device_initializer, network_initializer, sample_initializer, generate_initializer
from utils.utils import plot_images, save_images, save_one_image_in_images, check_and_create_dir
Expand Down Expand Up @@ -165,4 +166,6 @@ def generate(args):
parser.add_argument("--num_classes", type=int, default=10)

args = parser.parse_args()
# Get version banner
get_version_banner()
generate(args)
4 changes: 3 additions & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from config.choices import sample_choices, network_choices, optim_choices, act_choices, lr_func_choices, \
image_format_choices, noise_schedule_choices, parse_image_size_type, loss_func_choices
from config.setting import MASTER_ADDR, MASTER_PORT, EMA_BETA
from config.version import get_version_banner
from model.modules.ema import EMA
from utils.check import check_image_size
from utils.dataset import get_dataset
Expand Down Expand Up @@ -431,5 +432,6 @@ def main(args):
parser.add_argument("--cfg_scale", type=int, default=3)

args = parser.parse_args()

# Get version banner
get_version_banner()
main(args)

0 comments on commit e6d3e26

Please sign in to comment.