From 82b6f1a45563c9156f2a3d43b3364d8abca80559 Mon Sep 17 00:00:00 2001 From: HubHop Date: Tue, 28 Mar 2023 21:00:43 +1100 Subject: [PATCH] add code for levits --- .gitignore | 4 + README.md | 1 + stitching_levit/.gitignore | 113 +++++ stitching_levit/LICENSE | 201 ++++++++ stitching_levit/README.md | 57 +++ stitching_levit/datasets.py | 120 +++++ stitching_levit/engine.py | 180 ++++++++ stitching_levit/hubconf.py | 6 + stitching_levit/levit.py | 560 +++++++++++++++++++++++ stitching_levit/levit_c.py | 457 ++++++++++++++++++ stitching_levit/losses.py | 66 +++ stitching_levit/main.py | 416 +++++++++++++++++ stitching_levit/pretrained/readme.txt | 1 + stitching_levit/results/stitches_res.txt | 8 + stitching_levit/run_with_submitit.py | 137 ++++++ stitching_levit/samplers.py | 63 +++ stitching_levit/snnet.py | 136 ++++++ stitching_levit/utils.py | 402 ++++++++++++++++ 18 files changed, 2928 insertions(+) create mode 100755 stitching_levit/.gitignore create mode 100755 stitching_levit/LICENSE create mode 100755 stitching_levit/README.md create mode 100755 stitching_levit/datasets.py create mode 100755 stitching_levit/engine.py create mode 100755 stitching_levit/hubconf.py create mode 100755 stitching_levit/levit.py create mode 100755 stitching_levit/levit_c.py create mode 100755 stitching_levit/losses.py create mode 100755 stitching_levit/main.py create mode 100644 stitching_levit/pretrained/readme.txt create mode 100644 stitching_levit/results/stitches_res.txt create mode 100755 stitching_levit/run_with_submitit.py create mode 100755 stitching_levit/samplers.py create mode 100755 stitching_levit/snnet.py create mode 100755 stitching_levit/utils.py diff --git a/.gitignore b/.gitignore index c3d05ad..1e67c70 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,7 @@ dmypy.json # Pyre type checker .pyre/ +*.tar +*.pth +*.pt +*.gz \ No newline at end of file diff --git a/README.md b/README.md index 021e5dc..86cd7f9 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ By [Zizheng Pan](https://scholar.google.com.au/citations?user=w_VMopoAAAAJ&hl=en ## News +- 28/03/2023. Code for stitching LeViTs has been released. - 27/03/2023. We release the code and checkpoints for stitching ResNets and Swin Transformers. - 22/03/2023. SN-Net was selected as a highlight at CVPR 2023!🔥 - 02/03/2023. We release the source code! Any issues are welcomed! diff --git a/stitching_levit/.gitignore b/stitching_levit/.gitignore new file mode 100755 index 0000000..b4fe8e1 --- /dev/null +++ b/stitching_levit/.gitignore @@ -0,0 +1,113 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# PyCharm +.idea + +output/ + +# PyTorch weights +*.tar +*.pth +*.pt +*.gz +Untitled.ipynb +Testing notebook.ipynb +output/* +logs/* \ No newline at end of file diff --git a/stitching_levit/LICENSE b/stitching_levit/LICENSE new file mode 100755 index 0000000..b1395e9 --- /dev/null +++ b/stitching_levit/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 - present, Facebook, Inc + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/stitching_levit/README.md b/stitching_levit/README.md new file mode 100755 index 0000000..f88aaf0 --- /dev/null +++ b/stitching_levit/README.md @@ -0,0 +1,57 @@ +# Stitchable Neural Networks 🪡 + +This directory contains the training and evaluation scripts for stitching LeViT-192/256. + + +## Requirements + +### Prepare Python Environment + +* PyTorch 1.10.1+ +* CUDA 11.1+ +* fvcore 0.1.5 + +### Prepare Pretrained Weights + +Download the pretrained weights of LeViT-192/256 from [here](https://github.com/facebookresearch/LeViT) and put them in the `pretrained/` directory. +The following commands can be helpful. 
+ +```bash +cd pretrained/ +wget https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth +wget https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth +``` + +## Training + +To stitch LeViT-192/256 on ImageNet with 8 GPUs, run the following command: + +```bash +python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model stitch_levits \ + --data-path [path/to/imagenet] \ + --output_dir ./exp_levit_192_256 \ + --epochs 100 \ + --batch-size 128 \ + --lr 5e-5 \ + --warmup-lr 1e-7 \ + --min-lr 1e-6 +``` + +## Evaluation + +You can download our trained weights from [here](). Next, + +```bash +python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py --model stitch_levits \ + --data-path [path/to/imagenet] \ + --output_dir ./eval_levit_192_256 \ + --batch-size 128 \ + --resume [path/to/checkpoint.pth] --eval +``` + +After evaluation, you can find a `stitches_res.txt` under the `output_dir` directory which contains the results for all stitches. Our evaluation results can be found at `results/stitches_res.txt`. + + +## Acknowledgement + +This code is based on [LeViT](https://github.com/facebookresearch/LeViT). We thank the authors for their released code. \ No newline at end of file diff --git a/stitching_levit/datasets.py b/stitching_levit/datasets.py new file mode 100755 index 0000000..d4217cb --- /dev/null +++ b/stitching_levit/datasets.py @@ -0,0 +1,120 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +import os +import json + +from torchvision import datasets, transforms +from torchvision.datasets.folder import ImageFolder, default_loader +import torch + +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import create_transform + + +class INatDataset(ImageFolder): + def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, + category='name', loader=default_loader): + self.transform = transform + self.loader = loader + self.target_transform = target_transform + self.year = year + # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] + path_json = os.path.join( + root, f'{"train" if train else "val"}{year}.json') + with open(path_json) as json_file: + data = json.load(json_file) + + with open(os.path.join(root, 'categories.json')) as json_file: + data_catg = json.load(json_file) + + path_json_for_targeter = os.path.join(root, f"train{year}.json") + + with open(path_json_for_targeter) as json_file: + data_for_targeter = json.load(json_file) + + targeter = {} + indexer = 0 + for elem in data_for_targeter['annotations']: + king = [] + king.append(data_catg[int(elem['category_id'])][category]) + if king[0] not in targeter.keys(): + targeter[king[0]] = indexer + indexer += 1 + self.nb_classes = len(targeter) + + self.samples = [] + for elem in data['images']: + cut = elem['file_name'].split('/') + target_current = int(cut[2]) + path_current = os.path.join(root, cut[0], cut[2], cut[3]) + + categors = data_catg[target_current] + target_current_true = targeter[categors[category]] + self.samples.append((path_current, target_current_true)) + + # __getitem__ and __len__ inherited from ImageFolder + + +def build_dataset(is_train, args): + transform = build_transform(is_train, args) + + if args.data_set == 'CIFAR': + dataset = datasets.CIFAR100( + args.data_path, train=is_train, transform=transform) + nb_classes = 100 + elif args.data_set == 'IMNET': + root = os.path.join(args.data_path, 'train' if is_train else 'val') + dataset = 
datasets.ImageFolder(root, transform=transform) + nb_classes = 1000 + elif args.data_set == 'FLOWERS': + root = os.path.join(args.data_path, 'train' if is_train else 'test') + dataset = datasets.ImageFolder(root, transform=transform) + if is_train: + dataset = torch.utils.data.ConcatDataset( + [dataset for _ in range(100)]) + nb_classes = 102 + elif args.data_set == 'INAT': + dataset = INatDataset(args.data_path, train=is_train, year=2018, + category=args.inat_category, transform=transform) + nb_classes = dataset.nb_classes + elif args.data_set == 'INAT19': + dataset = INatDataset(args.data_path, train=is_train, year=2019, + category=args.inat_category, transform=transform) + nb_classes = dataset.nb_classes + + return dataset, nb_classes + + +def build_transform(is_train, args): + resize_im = args.input_size > 32 + if is_train: + # this should always dispatch to transforms_imagenet_train + transform = create_transform( + input_size=args.input_size, + is_training=True, + color_jitter=args.color_jitter, + auto_augment=args.aa, + interpolation=args.train_interpolation, + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + ) + if not resize_im: + # replace RandomResizedCropAndInterpolation with + # RandomCrop + transform.transforms[0] = transforms.RandomCrop( + args.input_size, padding=4) + return transform + + t = [] + if resize_im: + size = int((256 / 224) * args.input_size) + t.append( + # to maintain same ratio w.r.t. 224 images + transforms.Resize(size, interpolation=3), + ) + t.append(transforms.CenterCrop(args.input_size)) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) + return transforms.Compose(t) diff --git a/stitching_levit/engine.py b/stitching_levit/engine.py new file mode 100755 index 0000000..3c70c83 --- /dev/null +++ b/stitching_levit/engine.py @@ -0,0 +1,180 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. 
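As a quick reference for the data pipeline in `datasets.py` above, the sketch below shows how `build_dataset` is typically wired into an evaluation loader. The `Namespace` fields are simply the attributes `build_dataset` reads on this path, and the ImageNet path is a placeholder in the same spirit as the README's `[path/to/imagenet]`; this is a hedged illustration, not the exact driver code in `main.py`.

```python
import argparse
import torch
from datasets import build_dataset

# Placeholder path; only `data_set`, `data_path` and `input_size` are read
# by build_dataset/build_transform when is_train=False.
args = argparse.Namespace(data_set='IMNET',
                          data_path='/path/to/imagenet',
                          input_size=224)

dataset_val, nb_classes = build_dataset(is_train=False, args=args)
val_loader = torch.utils.data.DataLoader(dataset_val, batch_size=128,
                                         shuffle=False, num_workers=8)
# The `evaluate` / `evaluate_snnet` functions in engine.py below consume a
# loader like this one.
```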
+""" +Train and eval functions used in main.py +""" +import math +import sys +from typing import Iterable, Optional + +import torch + +from timm.data import Mixup +from timm.utils import accuracy, ModelEma + +from losses import DistillationLoss +import utils + +def initialize_model_stitching_layer(model, mixup_fn, data_loader, device): + for samples, targets in data_loader: + samples = samples.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + + if mixup_fn is not None: + samples, targets = mixup_fn(samples, targets) + + model.initialize_stitching_weights(samples) + + break + +def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + clip_grad: float = 0, + clip_mode: str = 'norm', + model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, + set_training_mode=True): + model.train(set_training_mode) + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue( + window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + print_freq = 10 + + for samples, targets in metric_logger.log_every( + data_loader, print_freq, header): + samples = samples.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + + if mixup_fn is not None: + samples, targets = mixup_fn(samples, targets) + + with torch.cuda.amp.autocast(): + outputs = model(samples) + loss = criterion(samples, outputs, targets) + + loss_value = loss.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + optimizer.zero_grad() + + # this attribute is added by timm on one optimizer (adahessian) + is_second_order = hasattr( + optimizer, 'is_second_order') and optimizer.is_second_order + loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode, + parameters=model.parameters(), create_graph=is_second_order) + + torch.cuda.synchronize() + if model_ema is not None: + model_ema.update(model) + + metric_logger.update(loss=loss_value) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +import os +import json +from fvcore.nn import FlopCountAnalysis + +@torch.no_grad() +def evaluate_snnet(data_loader, model, device, output_dir): + # check last config: + last_cfg_id = -1 + if os.path.exists(output_dir): + with open(output_dir, 'r') as f: + for line in f.readlines(): + epoch_stat = json.loads(line.strip()) + last_cfg_id = epoch_stat['cfg_id'] + + criterion = torch.nn.CrossEntropyLoss() + + header = 'Test:' + + # switch to evaluation mode + model.eval() + + if hasattr(model, 'module'): + num_configs = model.module.num_configs + else: + num_configs = model.num_configs + + for cfg_id in range(last_cfg_id+1, num_configs): + if hasattr(model, 'module'): + model.module.reset_stitch_id(cfg_id) + else: + model.reset_stitch_id(cfg_id) + + print(f'------------- Evaluting stitch config {cfg_id}/{num_configs} -------------') + + flops = FlopCountAnalysis(model, torch.randn(1, 3, 224, 224).cuda()) + print(f'FLOPs = {round(flops.total() / 1e9, 2)}') + + metric_logger = utils.MetricLogger(delimiter=" ") + + for images, target in metric_logger.log_every(data_loader, 10, header): + images = images.to(device, 
non_blocking=True) + target = target.to(device, non_blocking=True) + + # compute output + with torch.cuda.amp.autocast(): + output = model(images) + loss = criterion(output, target) + + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + batch_size = images.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print('cfg_id = ' + str( + cfg_id) + ' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' + .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) + + log_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + log_stats['cfg_id'] = cfg_id + log_stats['flops'] = flops.total() + utils.save_on_master_eval_res(log_stats, output_dir) + + + +@torch.no_grad() +def evaluate(data_loader, model, device): + criterion = torch.nn.CrossEntropyLoss() + + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + + # switch to evaluation mode + model.eval() + + for images, target in metric_logger.log_every(data_loader, 10, header): + images = images.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + + # compute output + with torch.cuda.amp.autocast(): + output = model(images) + loss = criterion(output, target) + + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + batch_size = images.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' + .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) + print(output.mean().item(), output.std().item()) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/stitching_levit/hubconf.py b/stitching_levit/hubconf.py new file mode 100755 index 0000000..a1b0c5d --- /dev/null +++ b/stitching_levit/hubconf.py @@ -0,0 +1,6 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +from levit import LeViT_128S, LeViT_128, LeViT_192, LeViT_256, LeViT_384 + +dependencies = ["torch", "torchvision", "timm"] diff --git a/stitching_levit/levit.py b/stitching_levit/levit.py new file mode 100755 index 0000000..ee8a3a8 --- /dev/null +++ b/stitching_levit/levit.py @@ -0,0 +1,560 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. 
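`levit.py` below defines the LeViT family used as stitching anchors and registers each variant with timm; its `__main__` block performs essentially the check sketched here. A minimal usage sketch for loading one anchor with its ImageNet weights (the checkpoint is fetched from the URL listed in `specification`):

```python
import torch
from levit import LeViT_192

model = LeViT_192(pretrained=True)   # downloads LeViT-192-92712e41.pth
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)                  # torch.Size([1, 1000])
print(model.FLOPS, 'FLOPs (analytic counter built at construction time)')
```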
+ +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License + +import torch +import itertools +import utils + +from timm.models.vision_transformer import trunc_normal_ +from timm.models.registry import register_model + +specification = { + 'LeViT_128S': { + 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, + 'LeViT_128': { + 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, + 'LeViT_192': { + 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, + 'LeViT_256': { + 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, + 'LeViT_384': { + 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, +} + +__all__ = [specification.keys()] + + +@register_model +def LeViT_128S(num_classes=1000, distillation=True, + pretrained=False, pretrained_cfg=None, fuse=False): + return model_factory(**specification['LeViT_128S'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def LeViT_128(num_classes=1000, distillation=True, + pretrained=False, pretrained_cfg=None, fuse=False): + return model_factory(**specification['LeViT_128'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def LeViT_192(num_classes=1000, distillation=True, + pretrained=False, pretrained_cfg=None, fuse=False): + return model_factory(**specification['LeViT_192'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def LeViT_256(num_classes=1000, distillation=True, + pretrained=False, pretrained_cfg=None, fuse=False): + return model_factory(**specification['LeViT_256'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def LeViT_384(num_classes=1000, distillation=True, + pretrained=False, pretrained_cfg=None, fuse=False): + return model_factory(**specification['LeViT_384'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +FLOPS_COUNTER = 0 + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + global FLOPS_COUNTER + output_points = ((resolution + 2 * pad - dilation * + (ks - 1) - 1) // stride + 1)**2 + FLOPS_COUNTER += a * b * output_points * (ks**2) // groups + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, 
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class Linear_BN(torch.nn.Sequential): + def __init__(self, a, b, bn_weight_init=1, resolution=-100000): + super().__init__() + self.add_module('c', torch.nn.Linear(a, b, bias=False)) + bn = torch.nn.BatchNorm1d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + global FLOPS_COUNTER + output_points = resolution**2 + FLOPS_COUNTER += a * b * output_points + + @torch.no_grad() + def fuse(self): + l, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = l.weight * w[:, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + def forward(self, x): + l, bn = self._modules.values() + x = l(x) + return bn(x.flatten(0, 1)).reshape_as(x) + + +class BN_Linear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + l = torch.nn.Linear(a, b, bias=bias) + trunc_normal_(l.weight, std=std) + if bias: + torch.nn.init.constant_(l.bias, 0) + self.add_module('l', l) + global FLOPS_COUNTER + FLOPS_COUNTER += a * b + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps)**0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def b16(n, activation, resolution=224): + return torch.nn.Sequential( + Conv2d_BN(3, n // 8, 3, 2, 1, resolution=resolution), + activation(), + Conv2d_BN(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + Conv2d_BN(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(torch.nn.Module): + def __init__(self, m, drop): + super().__init__() + self.is_subsample_layer = False + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand(x.size(0), 1, 1, + device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + activation=None, + resolution=14): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = Linear_BN(dim, h, resolution=resolution) + self.proj = torch.nn.Sequential(activation(), Linear_BN( + self.dh, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = 
torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N)) + + global FLOPS_COUNTER + #queries * keys + FLOPS_COUNTER += num_heads * (resolution**4) * key_dim + # softmax + FLOPS_COUNTER += num_heads * (resolution**4) + #attention * v + FLOPS_COUNTER += num_heads * self.d * (resolution**4) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, self.num_heads, - + 1).split([self.key_dim, self.key_dim, self.d], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ( + (q @ k.transpose(-2, -1)) * self.scale + + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class Subsample(torch.nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + self.is_subsample_layer = True + + def forward(self, x): + B, N, C = x.shape + x = x.view(B, self.resolution, self.resolution, C)[ + :, ::self.stride, ::self.stride].reshape(B, -1, C) + return x + + +class AttentionSubsample(torch.nn.Module): + def __init__(self, in_dim, out_dim, key_dim, num_heads=8, + attn_ratio=2, + activation=None, + stride=2, + resolution=14, resolution_=7): + super().__init__() + self.is_subsample_layer = True + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_**2 + h = self.dh + nh_kd + self.kv = Linear_BN(in_dim, h, resolution=resolution) + + self.q = torch.nn.Sequential( + Subsample(stride, resolution), + Linear_BN(in_dim, nh_kd, resolution=resolution_)) + self.proj = torch.nn.Sequential(activation(), Linear_BN( + self.dh, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product( + range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N_, N)) + + global FLOPS_COUNTER + #queries * keys + FLOPS_COUNTER += num_heads * \ + (resolution**2) * (resolution_**2) * key_dim + # softmax + FLOPS_COUNTER += num_heads * (resolution**2) * (resolution_**2) + #attention * v + FLOPS_COUNTER += num_heads * \ + (resolution**2) * (resolution_**2) * self.d + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, 
self.attention_bias_idxs] + + def forward(self, x): + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, - + 1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_2, self.num_heads, + self.key_dim).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + \ + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = self.proj(x) + return x + + +class LeViT(torch.nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=[192], + key_dim=[64], + depth=[12], + num_heads=[3], + attn_ratio=[2], + mlp_ratio=[2], + hybrid_backbone=None, + down_ops=[], + attention_activation=torch.nn.Hardswish, + mlp_activation=torch.nn.Hardswish, + distillation=True, + drop_path=0): + super().__init__() + global FLOPS_COUNTER + + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.distillation = distillation + + self.patch_embed = hybrid_backbone + self.depth = depth + + self.blocks = [] + down_ops.append(['']) + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual(Attention( + ed, kd, nh, + attn_ratio=ar, + activation=attention_activation, + resolution=resolution, + ), drop_path)) + if mr > 0: # ffn/mlp + h = int(ed * mr) + self.blocks.append( + Residual(torch.nn.Sequential( + Linear_BN(ed, h, resolution=resolution), + mlp_activation(), + Linear_BN(h, ed, bn_weight_init=0, + resolution=resolution), + ), drop_path)) + if do[0] == 'Subsample': + #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], + attn_ratio=do[3], + activation=attention_activation, + stride=do[5], + resolution=resolution, + resolution_=resolution_)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(torch.nn.Sequential( + Linear_BN(embed_dim[i + 1], h, + resolution=resolution), + mlp_activation(), + Linear_BN( + h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path)) + self.blocks = torch.nn.Sequential(*self.blocks) + + # find donwsampling layers + self.subsample_ids = [] + for i, mod in enumerate(self.blocks): + if mod.is_subsample_layer: + self.subsample_ids.append(i) + + # Classifier head + self.head = BN_Linear( + embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + if distillation: + self.head_dist = BN_Linear( + embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + self.FLOPS = FLOPS_COUNTER + FLOPS_COUNTER = 0 + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def extract_block_features(self, x): + x = self.patch_embed(x) + x = x.flatten(2).transpose(1, 2) + + features = {} + subsample_ids = self.subsample_ids + [len(self.blocks)] + start_idx = -2 + + for idx, sub_idx in enumerate(subsample_ids): + stage_feats = {} + stage_feats[-1] = x.detach() + for i, blk in 
enumerate(self.blocks[start_idx+2:sub_idx]): + x = blk(x) + if i % 2 == 1: # only save block outputs + stage_feats[i//2] = x.detach() + features[idx] = stage_feats + start_idx = sub_idx + x = self.blocks[sub_idx:sub_idx + 2](x) + + return features + + def forward_until(self, x, stage_id, blk_id): + x = self.patch_embed(x) + x = x.flatten(2).transpose(1, 2) + + idx = stage_id * 10 + (blk_id + 1) * 2 + + for blk in self.blocks[:idx]: + x = blk(x) + return x + + def forward_from(self, x, stage_id, blk_id): + idx = stage_id * 10 + blk_id * 2 + + for blk in self.blocks[idx:]: + x = blk(x) + + x = x.mean(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + def forward(self, x): + x = self.patch_embed(x) + x = x.flatten(2).transpose(1, 2) + x = self.blocks(x) + x = x.mean(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + +def model_factory(C, D, X, N, drop_path, weights, + num_classes, distillation, pretrained, fuse): + embed_dim = [int(x) for x in C.split('_')] + num_heads = [int(x) for x in N.split('_')] + depth = [int(x) for x in X.split('_')] + act = torch.nn.Hardswish + model = LeViT( + patch_size=16, + embed_dim=embed_dim, + num_heads=num_heads, + key_dim=[D] * 3, + depth=depth, + attn_ratio=[2, 2, 2], + mlp_ratio=[2, 2, 2], + down_ops=[ + #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ['Subsample', D, embed_dim[0] // D, 4, 2, 2], + ['Subsample', D, embed_dim[1] // D, 4, 2, 2], + ], + attention_activation=act, + mlp_activation=act, + hybrid_backbone=b16(embed_dim[0], activation=act), + num_classes=num_classes, + drop_path=drop_path, + distillation=distillation + ) + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + weights, map_location='cpu') + model.load_state_dict(checkpoint['model']) + if fuse: + utils.replace_batchnorm(model) + + return model + + +if __name__ == '__main__': + for name in specification: + net = globals()[name](fuse=True, pretrained=True) + net.eval() + net(torch.randn(4, 3, 224, 224)) + print(name, + net.FLOPS, 'FLOPs', + sum(p.numel() for p in net.parameters() if p.requires_grad), 'parameters') diff --git a/stitching_levit/levit_c.py b/stitching_levit/levit_c.py new file mode 100755 index 0000000..e7dc296 --- /dev/null +++ b/stitching_levit/levit_c.py @@ -0,0 +1,457 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. 
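The `forward_until` / `forward_from` hooks defined in `levit.py` above are what the stitched network uses to exit one anchor at a chosen block and resume in the other. The sketch below composes one stitched forward pass between the two anchors used in this repo; the plain `nn.Linear` is only an untrained stand-in for the stitching layers defined in `snnet.py` (which are initialised via `initialize_stitching_weights`), and the particular split points are illustrative rather than taken from the actual stitch configuration space.

```python
import torch
from levit import LeViT_192, LeViT_256

small = LeViT_192(pretrained=True).eval()   # 192-dim tokens in stage 0
large = LeViT_256(pretrained=True).eval()   # 256-dim tokens in stage 0

# Stand-in stitching layer mapping token dim 192 -> 256 (untrained here,
# so the resulting logits are not meaningful accuracy-wise).
stitch = torch.nn.Linear(small.embed_dim[0], large.embed_dim[0])

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = small.forward_until(x, stage_id=0, blk_id=1)       # first two attn/MLP pairs of LeViT-192
    feats = stitch(feats)                                      # (B, 196, 192) -> (B, 196, 256)
    logits = large.forward_from(feats, stage_id=0, blk_id=2)   # resume LeViT-256 from its third pair
print(logits.shape)  # torch.Size([1, 1000])
```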
+ +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License + +import torch +import itertools +import utils + +from timm.models.vision_transformer import trunc_normal_ +from timm.models.registry import register_model + +specification = { + 'LeViT_c_128S': { + 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, + 'LeViT_c_128': { + 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, + 'LeViT_c_192': { + 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, + 'LeViT_c_256': { + 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, + 'LeViT_c_384': { + 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, + 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, +} + +__all__ = [specification.keys()] + + +@register_model +def LeViT_c_128S(num_classes=1000, distillation=True, + pretrained=False, pretrained_cfg=None, fuse=False): + return model_factory(**specification['LeViT_c_128S'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def LeViT_c_128(num_classes=1000, distillation=True, + pretrained=False, pretrained_cfg=None, fuse=False): + return model_factory(**specification['LeViT_c_128'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def LeViT_c_192(num_classes=1000, distillation=True, + pretrained=False, pretrained_cfg=None, fuse=False): + return model_factory(**specification['LeViT_c_192'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def LeViT_c_256(num_classes=1000, distillation=True, + pretrained=False, pretrained_cfg=None, fuse=False): + return model_factory(**specification['LeViT_c_256'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +@register_model +def LeViT_c_384(num_classes=1000, distillation=True, + pretrained=False, pretrained_cfg=None, fuse=False): + return model_factory(**specification['LeViT_c_384'], num_classes=num_classes, + distillation=distillation, pretrained=pretrained, fuse=fuse) + + +FLOPS_COUNTER = 0 + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + global FLOPS_COUNTER + output_points = ((resolution + 2 * pad - dilation * + (ks - 1) - 1) // stride + 1)**2 + FLOPS_COUNTER += a * b * output_points * (ks**2) // groups + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], 
stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class BN_Linear(torch.nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', torch.nn.BatchNorm1d(a)) + l = torch.nn.Linear(a, b, bias=bias) + trunc_normal_(l.weight, std=std) + if bias: + torch.nn.init.constant_(l.bias, 0) + self.add_module('l', l) + global FLOPS_COUNTER + FLOPS_COUNTER += a * b + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + b = bn.bias - self.bn.running_mean * \ + self.bn.weight / (bn.running_var + bn.eps)**0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = torch.nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def b16(n, activation, resolution=224): + return torch.nn.Sequential( + Conv2d_BN(3, n // 8, 3, 2, 1, resolution=resolution), + activation(), + Conv2d_BN(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + Conv2d_BN(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(torch.nn.Module): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1, + device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + activation=None, + resolution=14): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + self.qkv = Conv2d_BN(dim, h, resolution=resolution) + self.proj = torch.nn.Sequential(activation(), Conv2d_BN( + self.dh, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N)) + + global FLOPS_COUNTER + #queries * keys + FLOPS_COUNTER += num_heads * (resolution**4) * key_dim + # softmax + FLOPS_COUNTER += num_heads * (resolution**4) + #attention * v + FLOPS_COUNTER += num_heads * self.d * (resolution**4) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,C,H,W) + B, C, H, W = x.shape + q, k, v = self.qkv(x).view( + B, self.num_heads, -1, H * W + ).split([self.key_dim, self.key_dim, self.d], dim=2) + attn = ( + (q.transpose(-2, -1) @ k) * self.scale + + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + 
) + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + x = self.proj(x) + return x + + +class AttentionSubsample(torch.nn.Module): + def __init__(self, in_dim, out_dim, key_dim, num_heads=8, + attn_ratio=2, + activation=None, + stride=2, + resolution=14, resolution_=7): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_**2 + h = self.dh + nh_kd + self.kv = Conv2d_BN(in_dim, h, resolution=resolution) + self.q = torch.nn.Sequential( + torch.nn.AvgPool2d(1, stride, 0), + Conv2d_BN(in_dim, nh_kd, resolution=resolution_)) + self.proj = torch.nn.Sequential( + activation(), Conv2d_BN(self.d * num_heads, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product( + range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N_, N)) + + global FLOPS_COUNTER + #queries * keys + FLOPS_COUNTER += num_heads * \ + (resolution**2) * (resolution_**2) * key_dim + # softmax + FLOPS_COUNTER += num_heads * (resolution**2) * (resolution_**2) + #attention * v + FLOPS_COUNTER += num_heads * \ + (resolution**2) * (resolution_**2) * self.d + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * + W).split([self.key_dim, self.d], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + \ + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape( + B, -1, self.resolution_, self.resolution_) + x = self.proj(x) + return x + + +class LeViT(torch.nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=[192], + key_dim=[64], + depth=[12], + num_heads=[3], + attn_ratio=[2], + mlp_ratio=[2], + hybrid_backbone=None, + down_ops=[], + attention_activation=torch.nn.Hardswish, + mlp_activation=torch.nn.Hardswish, + distillation=True, + drop_path=0): + super().__init__() + global FLOPS_COUNTER + + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + self.distillation = distillation + + self.patch_embed = hybrid_backbone + + self.blocks = [] + down_ops.append(['']) + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, 
depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual(Attention( + ed, kd, nh, + attn_ratio=ar, + activation=attention_activation, + resolution=resolution, + ), drop_path)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(torch.nn.Sequential( + Conv2d_BN(ed, h, resolution=resolution), + mlp_activation(), + Conv2d_BN(h, ed, bn_weight_init=0, + resolution=resolution), + ), drop_path)) + if do[0] == 'Subsample': + #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], + attn_ratio=do[3], + activation=attention_activation, + stride=do[5], + resolution=resolution, + resolution_=resolution_)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(torch.nn.Sequential( + Conv2d_BN(embed_dim[i + 1], h, + resolution=resolution), + mlp_activation(), + Conv2d_BN( + h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path)) + self.blocks = torch.nn.Sequential(*self.blocks) + + # Classifier head + self.head = BN_Linear( + embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + if distillation: + self.head_dist = BN_Linear( + embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + self.FLOPS = FLOPS_COUNTER + FLOPS_COUNTER = 0 + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def forward(self, x): + x = self.patch_embed(x) + x = self.blocks(x) + x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1) + if self.distillation: + x = self.head(x), self.head_dist(x) + if not self.training: + x = (x[0] + x[1]) / 2 + else: + x = self.head(x) + return x + + +def model_factory(C, D, X, N, drop_path, weights, + num_classes, distillation, pretrained, fuse): + embed_dim = [int(x) for x in C.split('_')] + num_heads = [int(x) for x in N.split('_')] + depth = [int(x) for x in X.split('_')] + act = torch.nn.Hardswish + model = LeViT( + patch_size=16, + embed_dim=embed_dim, + num_heads=num_heads, + key_dim=[D] * 3, + depth=depth, + attn_ratio=[2, 2, 2], + mlp_ratio=[2, 2, 2], + down_ops=[ + #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ['Subsample', D, embed_dim[0] // D, 4, 2, 2], + ['Subsample', D, embed_dim[1] // D, 4, 2, 2], + ], + attention_activation=act, + mlp_activation=act, + hybrid_backbone=b16(embed_dim[0], activation=act), + num_classes=num_classes, + drop_path=drop_path, + distillation=distillation + ) + if pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + weights, map_location='cpu') + d = checkpoint['model'] + D = model.state_dict() + for k in d.keys(): + if D[k].shape != d[k].shape: + d[k] = d[k][:, :, None, None] + model.load_state_dict(d) + if fuse: + utils.replace_batchnorm(model) + + return model + + +if __name__ == '__main__': + for name in specification: + net = globals()[name](fuse=True, pretrained=True) + net.eval() + net(torch.randn(4, 3, 224, 224)) + print(name, + net.FLOPS, 'FLOPs', + sum(p.numel() for p in net.parameters() if p.requires_grad), 'parameters') diff --git a/stitching_levit/losses.py b/stitching_levit/losses.py new file mode 100755 index 0000000..100e651 --- /dev/null +++ b/stitching_levit/losses.py @@ -0,0 +1,66 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. 
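A hedged usage sketch for the `DistillationLoss` defined below. The base criterion, `alpha` and `tau` values here are placeholders rather than this repo's training defaults; with `distillation_type='none'` the wrapper simply falls back to the base criterion, which is the easiest configuration to exercise standalone.

```python
import torch
from timm.loss import LabelSmoothingCrossEntropy
from losses import DistillationLoss

base = LabelSmoothingCrossEntropy(smoothing=0.1)
criterion = DistillationLoss(base_criterion=base, teacher_model=None,
                             distillation_type='none', alpha=0.5, tau=1.0)

# engine.train_one_epoch calls criterion(samples, outputs, targets):
outputs = torch.randn(4, 1000)             # model logits
targets = torch.randint(0, 1000, (4,))     # hard labels
samples = torch.randn(4, 3, 224, 224)      # unused when distillation_type='none'
loss = criterion(samples, outputs, targets)
print(loss.item())
```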
+
+"""
+Implements the knowledge distillation loss
+"""
+import torch
+from torch.nn import functional as F
+
+
+class DistillationLoss(torch.nn.Module):
+    """
+    This module wraps a standard criterion and adds an extra knowledge distillation loss by
+    taking a teacher model prediction and using it as additional supervision.
+    """
+
+    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
+                 distillation_type: str, alpha: float, tau: float):
+        super().__init__()
+        self.base_criterion = base_criterion
+        self.teacher_model = teacher_model
+        assert distillation_type in ['none', 'soft', 'hard']
+        self.distillation_type = distillation_type
+        self.alpha = alpha
+        self.tau = tau
+
+    def forward(self, inputs, outputs, labels):
+        """
+        Args:
+            inputs: The original inputs that are fed to the teacher model
+            outputs: the outputs of the model to be trained. It is expected to be
+                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
+                in the first position and the distillation predictions as the second output
+            labels: the labels for the base criterion
+        """
+        outputs_kd = None
+        if not isinstance(outputs, torch.Tensor):
+            # assume that the model outputs a tuple of [outputs, outputs_kd]
+            outputs, outputs_kd = outputs
+        base_loss = self.base_criterion(outputs, labels)
+        if self.distillation_type == 'none':
+            return base_loss
+
+        if outputs_kd is None:
+            raise ValueError("When knowledge distillation is enabled, the model is "
+                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
+                             "class_token and the dist_token")
+        # don't backprop through the teacher
+        with torch.no_grad():
+            teacher_outputs = self.teacher_model(inputs)
+
+        if self.distillation_type == 'soft':
+            T = self.tau
+            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
+            # with slight modifications
+            distillation_loss = F.kl_div(
+                F.log_softmax(outputs_kd / T, dim=1),
+                F.log_softmax(teacher_outputs / T, dim=1),
+                reduction='sum',
+                log_target=True
+            ) * (T * T) / outputs_kd.numel()
+        elif self.distillation_type == 'hard':
+            distillation_loss = F.cross_entropy(
+                outputs_kd, teacher_outputs.argmax(dim=1))
+
+        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
+        return loss
diff --git a/stitching_levit/main.py b/stitching_levit/main.py
new file mode 100755
index 0000000..f7e0d8f
--- /dev/null
+++ b/stitching_levit/main.py
@@ -0,0 +1,416 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
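A minimal usage sketch for the DistillationLoss above. The tiny linear teacher below is a stand-in for illustration only; main.py builds a regnety_160 teacher through timm instead:

    import torch
    from losses import DistillationLoss

    # Stand-in teacher (assumption for this sketch): any module mapping images to 1000 logits works.
    teacher = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, 1000)).eval()
    criterion = DistillationLoss(torch.nn.CrossEntropyLoss(), teacher,
                                 distillation_type='hard', alpha=0.5, tau=1.0)

    images = torch.randn(2, 3, 224, 224)
    targets = torch.randint(0, 1000, (2,))
    # In training mode, a LeViT built with distillation=True yields (logits, logits_dist).
    outputs = (torch.randn(2, 1000, requires_grad=True), torch.randn(2, 1000, requires_grad=True))
    loss = criterion(images, outputs, targets)
    loss.backward()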
+import argparse +import datetime +import numpy as np +import time +import torch +import torch.backends.cudnn as cudnn +import json +import os + +from pathlib import Path + +from timm.data import Mixup +from timm.models import create_model +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from timm.scheduler import create_scheduler +from timm.optim import create_optimizer +from timm.utils import NativeScaler, get_state_dict, ModelEma + +from datasets import build_dataset +from engine import train_one_epoch, evaluate, initialize_model_stitching_layer, evaluate_snnet +from losses import DistillationLoss +from samplers import RASampler +import levit +import levit_c +import utils +from snnet import SNNet + +def get_args_parser(): + parser = argparse.ArgumentParser( + 'LeViT training and evaluation script', add_help=False) + parser.add_argument('--batch-size', default=256, type=int) + parser.add_argument('--epochs', default=1000, type=int) + + # Model parameters + parser.add_argument('--model', default='LeViT_256', type=str, metavar='MODEL', + help='Name of model to train') + parser.add_argument('--input-size', default=224, + type=int, help='images input size') + + parser.add_argument('--model-ema', action='store_true') + parser.add_argument( + '--no-model-ema', action='store_false', dest='model_ema') + parser.set_defaults(model_ema=True) + parser.add_argument('--model-ema-decay', type=float, + default=0.99996, help='') + parser.add_argument('--model-ema-force-cpu', + action='store_true', default=False, help='') + + # Optimizer parameters + parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "adamw"') + parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: 1e-8)') + parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') + parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--clip-mode', type=str, default='agc', + help='Gradient clipping mode. 
One of ("norm", "value", "agc")') + parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.9)') + parser.add_argument('--weight-decay', type=float, default=0.025, + help='weight decay (default: 0.025)') + # Learning rate schedule parameters + parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "cosine"') + parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', + help='learning rate (default: 5e-4)') + parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', + help='learning rate noise on/off epoch percentages') + parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', + help='learning rate noise limit percent (default: 0.67)') + parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', + help='learning rate noise std-dev (default: 1.0)') + parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', + help='warmup learning rate (default: 1e-6)') + parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') + + parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', + help='epoch interval to decay LR') + parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', + help='epochs to warmup LR, if scheduler supports') + parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') + parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', + help='patience epochs for Plateau LR scheduler (default: 10') + parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', + help='LR decay rate (default: 0.1)') + + # Augmentation parameters + parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', + help='Color jitter factor (default: 0.4)') + parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', + help='Use AutoAugment policy. "v0" or "original". " + \ + "(default: rand-m9-mstd0.5-inc1)'), + parser.add_argument('--smoothing', type=float, default=0.1, + help='Label smoothing (default: 0.1)') + parser.add_argument('--train-interpolation', type=str, default='bicubic', + help='Training interpolation (random, bilinear, bicubic default: "bicubic")') + + parser.add_argument('--repeated-aug', action='store_true') + parser.add_argument('--no-repeated-aug', + action='store_false', dest='repeated_aug') + parser.set_defaults(repeated_aug=True) + + # * Random Erase params + parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', + help='Random erase prob (default: 0.25)') + parser.add_argument('--remode', type=str, default='pixel', + help='Random erase mode (default: "pixel")') + parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') + parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') + + # * Mixup params + parser.add_argument('--mixup', type=float, default=0.8, + help='mixup alpha, mixup enabled if > 0. (default: 0.8)') + parser.add_argument('--cutmix', type=float, default=1.0, + help='cutmix alpha, cutmix enabled if > 0. 
(default: 1.0)') + parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') + parser.add_argument('--mixup-prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') + parser.add_argument('--mixup-switch-prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') + parser.add_argument('--mixup-mode', type=str, default='batch', + help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') + + # Distillation parameters + parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', + help='Name of teacher model to train (default: "regnety_160"') + parser.add_argument('--teacher-path', type=str, + default='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth') + parser.add_argument('--distillation-type', default='hard', + choices=['none', 'soft', 'hard'], type=str, help="") + parser.add_argument('--distillation-alpha', + default=0.5, type=float, help="") + parser.add_argument('--distillation-tau', default=1.0, type=float, help="") + + # * Finetuning params + parser.add_argument('--finetune', default='', + help='finetune from checkpoint') + + # Dataset parameters + parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, + help='dataset path') + parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], + type=str, help='Image Net dataset path') + parser.add_argument('--inat-category', default='name', + choices=['kingdom', 'phylum', 'class', 'order', + 'supercategory', 'family', 'genus', 'name'], + type=str, help='semantic granularity') + + parser.add_argument('--output_dir', default='', + help='path where to save, empty for no saving') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--start_epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true', + help='Perform evaluation only') + parser.add_argument('--dist-eval', action='store_true', + default=False, help='Enabling distributed evaluation') + parser.add_argument('--num_workers', default=10, type=int) + parser.add_argument('--pin-mem', action='store_true', + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', + help='') + parser.set_defaults(pin_mem=True) + + # distributed training parameters + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + return parser + + +def main(args): + utils.init_distributed_mode(args) + + print(args) + + if args.distillation_type != 'none' and args.finetune and not args.eval: + raise NotImplementedError( + "Finetuning with distillation not yet supported") + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + # random.seed(seed) + + cudnn.benchmark = True + + dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) + dataset_val, _ = 
build_dataset(is_train=False, args=args) + + if True: # args.distributed: + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + if args.repeated_aug: + sampler_train = RASampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + else: + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' + 'This will slightly alter validation results as extra duplicate entries are added to achieve ' + 'equal num of samples per-process.') + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + batch_size=int(1.5 * args.batch_size), + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False + ) + + mixup_fn = None + mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + if mixup_active: + mixup_fn = Mixup( + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, + label_smoothing=args.smoothing, num_classes=args.nb_classes) + + model_cards = [ + ('LeViT_192', 'pretrained/LeViT-192-92712e41.pth'), + ('LeViT_256', 'pretrained/LeViT-256-13b5763e.pth') + ] + anchors = [] + for name, checkpoint_path in model_cards: + model = create_model(name) + checkpoint = torch.load(checkpoint_path, map_location='cpu') + model.load_state_dict(checkpoint["model"]) + anchors.append(model) + + model = SNNet(anchors) + + if args.finetune: + if args.finetune.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.finetune, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.finetune, map_location='cpu') + + checkpoint_model = checkpoint['model'] + state_dict = model.state_dict() + for k in ['head.weight', 'head.bias', + 'head_dist.weight', 'head_dist.bias']: + if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: + print(f"Removing key {k} from pretrained checkpoint") + del checkpoint_model[k] + + model.load_state_dict(checkpoint_model, strict=False) + + model.to(device) + + temp_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=100, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + # solve by least square + initialize_model_stitching_layer(model, mixup_fn, temp_loader_train, device) + print('Stitching Layer Initialized') + del temp_loader_train + + model_ema = None + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], find_unused_parameters=True) + model_without_ddp = model.module + n_parameters = sum(p.numel() + for p in model.parameters() if p.requires_grad) + print('number of params:', 
n_parameters) + + linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 + args.lr = linear_scaled_lr + + optimizer = create_optimizer(args, model_without_ddp) + + loss_scaler = NativeScaler() + + lr_scheduler, _ = create_scheduler(args, optimizer) + + criterion = LabelSmoothingCrossEntropy() + + if args.mixup > 0.: + # smoothing is handled with mixup label transform + criterion = SoftTargetCrossEntropy() + elif args.smoothing: + criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) + else: + criterion = torch.nn.CrossEntropyLoss() + + teacher_model = None + if args.distillation_type != 'none': + assert args.teacher_path, 'need to specify teacher-path when using distillation' + print(f"Creating teacher model: {args.teacher_model}") + teacher_model = create_model( + args.teacher_model, + pretrained=False, + num_classes=args.nb_classes, + global_pool='avg', + ) + if args.teacher_path.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.teacher_path, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.teacher_path, map_location='cpu') + teacher_model.load_state_dict(checkpoint['model']) + teacher_model.to(device) + teacher_model.eval() + + # wrap the criterion in our custom DistillationLoss, which + # just dispatches to the original criterion if args.distillation_type is + # 'none' + criterion = DistillationLoss( + criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau + ) + + output_dir = Path(args.output_dir) + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + args.start_epoch = checkpoint['epoch'] + 1 + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + if args.eval: + evaluate_snnet(data_loader_val, model, device, os.path.join(args.output_dir, 'stitches_res.txt')) + return + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + max_accuracy = 0.0 + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + data_loader_train.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + model, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + args.clip_grad, args.clip_mode, model_ema, mixup_fn, + set_training_mode=args.finetune == '' # keep in eval mode during finetuning + ) + + lr_scheduler.step(epoch) + if args.output_dir: + checkpoint_paths = [output_dir / 'checkpoint.pth'] + for checkpoint_path in checkpoint_paths: + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + # 'model_ema': get_state_dict(model_ema), + 'scaler': loss_scaler.state_dict(), + 'args': args, + }, checkpoint_path) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + evaluate_snnet(data_loader_val, model, device, os.path.join(args.output_dir, 'stitches_res.txt')) + + + +if __name__ == '__main__': + parser = 
argparse.ArgumentParser( + 'LeViT training and evaluation script', parents=[get_args_parser()]) + args = parser.parse_args() + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + main(args) diff --git a/stitching_levit/pretrained/readme.txt b/stitching_levit/pretrained/readme.txt new file mode 100644 index 0000000..a6df56e --- /dev/null +++ b/stitching_levit/pretrained/readme.txt @@ -0,0 +1 @@ +place your pretrained deit checkpoints here \ No newline at end of file diff --git a/stitching_levit/results/stitches_res.txt b/stitching_levit/results/stitches_res.txt new file mode 100644 index 0000000..f0e0837 --- /dev/null +++ b/stitching_levit/results/stitches_res.txt @@ -0,0 +1,8 @@ +{"loss": 0.8239520951762966, "acc1": 79.63800263183593, "acc5": 94.53000249023438, "cfg_id": 0, "flops": 661663808} +{"loss": 0.7729868526602613, "acc1": 81.22200236328125, "acc5": 95.21800245117187, "cfg_id": 1, "flops": 1133412864} +{"loss": 0.8509775411294794, "acc1": 79.35200238769531, "acc5": 94.2600024609375, "cfg_id": 2, "flops": 1020918272} +{"loss": 0.792813112937856, "acc1": 80.72600254882812, "acc5": 94.92400262695313, "cfg_id": 3, "flops": 977741824} +{"loss": 0.7968292268585428, "acc1": 80.46600247558594, "acc5": 94.94800263671875, "cfg_id": 4, "flops": 934565376} +{"loss": 0.8009902323913757, "acc1": 80.44200267578125, "acc5": 94.81200263671874, "cfg_id": 5, "flops": 812300576} +{"loss": 0.8108080803377419, "acc1": 80.08000273925781, "acc5": 94.62600272460938, "cfg_id": 6, "flops": 791247040} +{"loss": 0.8195572130068052, "acc1": 79.68200244140625, "acc5": 94.5700024609375, "cfg_id": 7, "flops": 770193504} diff --git a/stitching_levit/run_with_submitit.py b/stitching_levit/run_with_submitit.py new file mode 100755 index 0000000..b53f25e --- /dev/null +++ b/stitching_levit/run_with_submitit.py @@ -0,0 +1,137 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +""" +A script to run multinode training with submitit. +""" +import argparse +import os +import uuid +from pathlib import Path + +import main as classification +import submitit + + +def parse_args(): + classification_parser = classification.get_args_parser() + parser = argparse.ArgumentParser( + "Submitit for DeiT", parents=[classification_parser]) + parser.add_argument("--ngpus", default=8, type=int, + help="Number of gpus to request on each node") + parser.add_argument("--nodes", default=1, type=int, + help="Number of nodes to request") + parser.add_argument("--timeout", default=4320, + type=int, help="Duration of the job") + parser.add_argument("--job_dir", default="", type=str, + help="Job dir. Leave empty for automatic.") + + parser.add_argument("--partition", default="learnfair", + type=str, help="Partition where to submit") + parser.add_argument("--use_volta32", action='store_true', + help="Big models? Use this") + parser.add_argument('--comment', default="", type=str, + help='Comment to pass to scheduler, e.g. priority message') + return parser.parse_args() + + +def get_shared_folder() -> Path: + user = os.getenv("USER") + if Path("/checkpoint/").is_dir(): + p = Path(f"/checkpoint/{user}/experiments") + p.mkdir(exist_ok=True) + return p + raise RuntimeError("No shared folder available") + + +def get_init_file(): + # Init file must not exist, but it's parent dir must exist. 
+ os.makedirs(str(get_shared_folder()), exist_ok=True) + init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" + if init_file.exists(): + os.remove(str(init_file)) + return init_file + + +class Trainer(object): + def __init__(self, args): + self.args = args + + def __call__(self): + import main as classification + + self._setup_gpu_args() + classification.main(self.args) + + def checkpoint(self): + import os + import submitit + + self.args.dist_url = get_init_file().as_uri() + checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") + if os.path.exists(checkpoint_file): + self.args.resume = checkpoint_file + print("Requeuing ", self.args) + empty_trainer = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty_trainer) + + def _setup_gpu_args(self): + import submitit + from pathlib import Path + + job_env = submitit.JobEnvironment() + self.args.output_dir = Path( + str(self.args.output_dir).replace("%j", str(job_env.job_id))) + self.args.gpu = job_env.local_rank + self.args.rank = job_env.global_rank + self.args.world_size = job_env.num_tasks + print( + f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + + +def main(): + args = parse_args() + if args.job_dir == "": + args.job_dir=f'{args.model}_{args.epochs}' + + # Note that the folder will depend on the job_id, to easily track + # experiments + executor = submitit.AutoExecutor( + folder=args.job_dir, slurm_max_num_timeout=30) + + num_gpus_per_node = args.ngpus + nodes = args.nodes + timeout_min = args.timeout + + partition = args.partition + kwargs = {} + if args.use_volta32: + kwargs['slurm_constraint'] = 'volta32gb' + if args.comment: + kwargs['slurm_comment'] = args.comment + + executor.update_parameters( + mem_gb=40 * num_gpus_per_node, + gpus_per_node=num_gpus_per_node, + tasks_per_node=num_gpus_per_node, # one task per GPU + cpus_per_task=10, + nodes=nodes, + timeout_min=timeout_min, # max is 60 * 72 + # Below are cluster dependent parameters + slurm_partition=partition, + slurm_signal_delay_s=120, + **kwargs + ) + + executor.update_parameters(name=f'LeViT_{args.model}') + + args.dist_url = get_init_file().as_uri() + args.output_dir = args.job_dir + + trainer = Trainer(args) + job = executor.submit(trainer) + + print("Submitted job_id:", job.job_id) + + +if __name__ == "__main__": + main() diff --git a/stitching_levit/samplers.py b/stitching_levit/samplers.py new file mode 100755 index 0000000..e76d6d3 --- /dev/null +++ b/stitching_levit/samplers.py @@ -0,0 +1,63 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +import torch +import torch.distributed as dist +import math + + +class RASampler(torch.utils.data.Sampler): + """Sampler that restricts data loading to a subset of the dataset for distributed, + with repeated augmentation. 
+ It ensures that different each augmented version of a sample will be visible to a + different process (GPU) + Heavily based on torch.utils.data.DistributedSampler + """ + + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError( + "Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError( + "Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int( + math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) + self.num_selected_samples = int(math.floor( + len(self.dataset) // 256 * 256 / self.num_replicas)) + self.shuffle = shuffle + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + if self.shuffle: + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices = [ele for ele in indices for i in range(3)] + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices[:self.num_selected_samples]) + + def __len__(self): + return self.num_selected_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/stitching_levit/snnet.py b/stitching_levit/snnet.py new file mode 100755 index 0000000..fe5ac36 --- /dev/null +++ b/stitching_levit/snnet.py @@ -0,0 +1,136 @@ +import os.path + +import torch.nn as nn +import torch +from collections import defaultdict +from timm.models import create_model +from utils import get_stitch_configs, ps_inv +from timm.models.registry import register_model +import numpy as np + + +class StitchingLayer(nn.Module): + def __init__(self, in_features=None, out_features=None): + super().__init__() + self.transform = nn.Linear(in_features, out_features) + + def init_stitch_weights_bias(self, weight, bias): + self.transform.weight.data.copy_(weight) + self.transform.bias.data.copy_(bias) + + def forward(self, x): + x = self.transform(x) + return x + + +class SNNet(nn.Module): + ''' + Stitchable Neural Networks + ''' + + def __init__(self, anchors): + super(SNNet, self).__init__() + + self.anchors = nn.ModuleList(anchors) + stage_depths = [mod.depth for mod in self.anchors] + + total_configs = [] + self.num_stitches = [] + self.stitch_layers = nn.ModuleList() + self.stitching_map_id = {} + + for i in range(len(self.anchors)): + total_configs.append({ + 'comb_id': [i], + 'stitch_cfgs': [], + 'stitch_layers': [] + }) + + for i in range(3): + if i == 2: + break + cur_depths = [stage_depths[mod_id][i] for mod_id in range(len(self.anchors))] + stage_configs, stage_stitches = get_stitch_configs(cur_depths, i) + self.num_stitches.append(stage_stitches) + total_configs += stage_configs + stage_stitching_layers = nn.ModuleList() + for j, (num_s, comb) in enumerate(stage_stitches): + front, end = comb + stage_stitching_layers.append(nn.ModuleList( + [StitchingLayer(self.anchors[front].embed_dim[i], self.anchors[end].embed_dim[i]) for _ in 
range(num_s)])) + self.stitching_map_id[f'{i}-{front}-{end}'] = j + self.stitch_layers.append(stage_stitching_layers) + + self.stitch_configs = {i: cfg for i, cfg in enumerate(total_configs)} + self.num_configs = len(total_configs) + self.stitch_config_id = 0 + + def reset_stitch_id(self, stitch_config_id): + self.stitch_config_id = stitch_config_id + + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + + def initialize_stitching_weights(self, x): + vit_features = [] + with torch.no_grad(): + for mod in self.anchors: + vit_features.append(mod.extract_block_features(x)) + + for stage_id in range(3): + if stage_id == 2: + break + stage_stitches = self.num_stitches[stage_id] + + for j, (num_s, comb) in enumerate(stage_stitches): + front, end = comb + stitching_dicts = defaultdict(set) + for id, config in self.stitch_configs.items(): + if config['comb_id'] == comb and stage_id == config['stage_id']: + stitching_dicts[config['stitch_layers'][0]].add(config['stitch_cfgs'][0]) + + for stitch_layer_id, stitch_positions in stitching_dicts.items(): + weight_candidates = [] + bias_candidates = [] + for front_id, end_id in stitch_positions: + front_blk_feat = vit_features[front][stage_id][front_id] + end_blk_feat = vit_features[end][stage_id][end_id - 1] + w, b = ps_inv(front_blk_feat, end_blk_feat) + weight_candidates.append(w) + bias_candidates.append(b) + weights = torch.stack(weight_candidates).mean(dim=0) + bias = torch.stack(bias_candidates).mean(dim=0) + + self.stitch_layers[stage_id][j][stitch_layer_id].init_stitch_weights_bias(weights, bias) + print(f'Initialized Stitching Model {front} to Model {end}, Stage {stage_id}, Layer {stitch_layer_id}') + + + def forward(self, x): + if self.training: + stitch_cfg_id = np.random.randint(0, self.num_configs) + else: + stitch_cfg_id = self.stitch_config_id + + comb_id = self.stitch_configs[stitch_cfg_id]['comb_id'] + if len(comb_id) == 1: + return self.anchors[comb_id[0]](x) + + stitch_cfgs = self.stitch_configs[stitch_cfg_id]['stitch_cfgs'] + stitch_stage_id = self.stitch_configs[stitch_cfg_id]['stage_id'] + stitch_layer_ids = self.stitch_configs[stitch_cfg_id]['stitch_layers'] + + cfg = stitch_cfgs[0] + + x = self.anchors[comb_id[0]].forward_until(x, stage_id=stitch_stage_id, blk_id=cfg[0]) + + sl_id = stitch_layer_ids[0] + key = f'{stitch_stage_id}-{comb_id[0]}-{comb_id[1]}' + stitch_projection_id = self.stitching_map_id[key] + x = self.stitch_layers[stitch_stage_id][stitch_projection_id][sl_id](x) + + x = self.anchors[comb_id[1]].forward_from(x, stage_id=stitch_stage_id, blk_id=cfg[1]) + + return x diff --git a/stitching_levit/utils.py b/stitching_levit/utils.py new file mode 100755 index 0000000..5082199 --- /dev/null +++ b/stitching_levit/utils.py @@ -0,0 +1,402 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import io +import os +import time +from collections import defaultdict, deque +import datetime + +import torch +import torch.distributed as dist +import torch.nn as nn +import json + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. 
+ """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], + dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + 
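A self-contained toy driver for SmoothedValue and MetricLogger above (a sketch, assuming this utils.py is importable; real training code would iterate a DataLoader rather than a list of tensors):

    import torch
    from utils import MetricLogger, SmoothedValue

    logger = MetricLogger(delimiter="  ")
    logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    fake_batches = [torch.randn(4, 10) for _ in range(6)]
    for batch in logger.log_every(fake_batches, print_freq=2, header='Demo:'):
        logger.update(loss=batch.pow(2).mean().item(), lr=5e-4)
    logger.synchronize_between_processes()  # no-op unless torch.distributed is initialized
    print('Averaged stats:', logger)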
+def _load_checkpoint_for_ema(model_ema, checkpoint): + """ + Workaround for ModelEma._load_checkpoint to accept an already-loaded object + """ + mem_file = io.BytesIO() + torch.save(checkpoint, mem_file) + mem_file.seek(0) + model_ema._load_checkpoint(mem_file) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def replace_batchnorm(net): + for child_name, child in net.named_children(): + if hasattr(child, 'fuse'): + setattr(net, child_name, child.fuse()) + elif isinstance(child, torch.nn.Conv2d): + child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0))) + elif isinstance(child, torch.nn.BatchNorm2d): + setattr(net, child_name, torch.nn.Identity()) + else: + replace_batchnorm(child) + + +def replace_layernorm(net): + import apex + for child_name, child in net.named_children(): + if isinstance(child, torch.nn.LayerNorm): + setattr(net, child_name, apex.normalization.FusedLayerNorm( + child.weight.size(0))) + else: + replace_layernorm(child) + + +def unpaired_stitching(front_depth=12, end_depth=24): + num_stitches = front_depth + + block_ids = torch.tensor(list(range(front_depth))) + block_ids = block_ids[None, None, :].float() + end_mapping_ids = torch.nn.functional.interpolate(block_ids, end_depth) + end_mapping_ids = end_mapping_ids.squeeze().long().tolist() + front_mapping_ids = block_ids.squeeze().long().tolist() + + stitch_cfgs = [] + for idx in front_mapping_ids: + for i, e_idx in enumerate(end_mapping_ids): + if idx != e_idx or idx >= i: + continue + else: + stitch_cfgs.append((idx, i)) + return stitch_cfgs, end_mapping_ids, num_stitches + +def paired_stitching(depth=12, kernel_size=2, stride=1): + blk_id = list(range(depth)) + i = 0 + stitch_cfgs = [] + stitch_id = -1 + stitching_layers_mappings = [] + + while i < depth: + ids = blk_id[i:i + kernel_size] + has_new_stitches = False + for j in ids: + for k in ids: + if (j, k) not in stitch_cfgs: + if j >= k: + 
continue + has_new_stitches = True + stitch_cfgs.append((j, k)) + stitching_layers_mappings.append(stitch_id + 1) + + if has_new_stitches: + stitch_id += 1 + + i += stride + + num_stitches = stitch_id + 1 + return stitch_cfgs, stitching_layers_mappings, num_stitches + + +def get_stitch_configs(depths, stage_id): + depths = sorted(depths) + + d = depths[0] + total_configs = [] + total_stitches = [] + + for i in range(1, len(depths)): + next_d = depths[i] + if next_d == d: + stitch_cfgs, layers_mappings, num_stitches = paired_stitching(d) + else: + stitch_cfgs, layers_mappings, num_stitches = unpaired_stitching(d, next_d) + comb = (i-1, i) + for cfg, layer_mapping_id in zip(stitch_cfgs, layers_mappings): + total_configs.append({ + 'comb_id': comb, + 'stage_id': stage_id, + 'stitch_cfgs': [cfg], + 'stitch_layers': [layer_mapping_id] + }) + total_stitches.append((num_stitches, comb)) + d = next_d + + return total_configs, total_stitches + +def rearrange_activations(activations): + n_channels = activations.shape[-1] + activations = activations.reshape(-1, n_channels) + return activations + +def ps_inv(x1, x2): + x1 = rearrange_activations(x1) + x2 = rearrange_activations(x2) + + if not x1.shape[0] == x2.shape[0]: + raise ValueError('Spatial size of compared neurons must match when ' \ + 'calculating psuedo inverse matrix.') + + # Get transformation matrix shape + shape = list(x1.shape) + shape[-1] += 1 + + # Calculate pseudo inverse + x1_ones = torch.ones(shape) + x1_ones[:, :-1] = x1 + A_ones = torch.matmul(torch.linalg.pinv(x1_ones), x2.to(x1_ones.device)).T + + # Get weights and bias + w = A_ones[..., :-1] + b = A_ones[..., -1] + + return w, b + + +def param_groups_weight_decay_stitch( + model: nn.Module, + weight_decay=1e-5, + no_weight_decay_list=(), + anchor_lr_scale=0.1, + stitch_lr_scale=0.1 +): + no_weight_decay_list = set(no_weight_decay_list) + decay = [] + no_decay = [] + stitching_weights = [] + stitching_bias = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: + if 'stitch' in name: + stitching_bias.append(param) + else: + no_decay.append(param) + else: + if 'stitch' in name: + stitching_weights.append(param) + else: + decay.append(param) + return [ + {'params': no_decay, 'weight_decay': 0., 'lr_scale': anchor_lr_scale}, + {'params': decay, 'weight_decay': weight_decay, 'lr_scale': anchor_lr_scale}, + {'params': stitching_weights, 'weight_decay': weight_decay, 'lr_scale': stitch_lr_scale}, + {'params': stitching_bias, 'weight_decay': 0., 'lr_scale': stitch_lr_scale} + ] + + +def save_on_master_eval_res(log_stats, output_dir): + if is_main_process(): + with open(output_dir, 'a') as f: + f.write(json.dumps(log_stats) + "\n")
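For reference, a small sanity check of the least-squares initialization implemented by ps_inv above. The shapes are illustrative only; in SNNet.initialize_stitching_weights the activations come from the anchors' block features:

    import torch
    from utils import ps_inv

    torch.manual_seed(0)
    front_feat = torch.randn(8, 196, 192)  # pretend activations from the narrower anchor
    end_feat = torch.randn(8, 196, 256)    # pretend activations from the wider anchor

    w, b = ps_inv(front_feat, end_feat)    # closed-form transform mapping 192-d features to 256-d
    print(w.shape, b.shape)                # torch.Size([256, 192]) torch.Size([256])

    # These match nn.Linear(192, 256), so they can seed a StitchingLayer via
    # StitchingLayer.init_stitch_weights_bias(w, b), as done in SNNet.initialize_stitching_weights.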