From b7c2b81eae45c53ae5ee29c507083bbf5dddca74 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 22 Feb 2022 10:55:33 +0000 Subject: [PATCH 1/3] Basic XLA support (1 device only). Note: the call to the new `optimizer_step` method has only been added to `SRModel`. --- basicsr/models/base_model.py | 18 +++++++++++++++++- basicsr/models/sr_model.py | 2 +- basicsr/train.py | 4 +++- basicsr/utils/accelerator_util.py | 25 +++++++++++++++++++++++++ basicsr/utils/options.py | 4 ++-- 5 files changed, 48 insertions(+), 5 deletions(-) create mode 100644 basicsr/utils/accelerator_util.py diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index f06f9ca2c..9fc40a7a5 100644 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -7,19 +7,28 @@ from basicsr.models import lr_scheduler as lr_scheduler from basicsr.utils import get_root_logger +from basicsr.utils import accelerator_util from basicsr.utils.dist_util import master_only +try: + import torch_xla.core.xla_model as xm +except: + pass class BaseModel(): """Base model.""" def __init__(self, opt): self.opt = opt - self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.device = accelerator_util.default_device(opt) self.is_train = opt['is_train'] self.schedulers = [] self.optimizers = [] + @property + def accelerator(self): + return accelerator_util.accelerator_name(self.opt) + def feed_data(self, data): pass @@ -107,6 +116,13 @@ def get_optimizer(self, optim_type, params, lr, **kwargs): raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.') return optimizer + def optimizer_step(self, optimizer): + if self.accelerator == 'xla': + xm.optimizer_step(optimizer) + xm.mark_step() + else: + optimizer.step() + def setup_schedulers(self): """Set up schedulers.""" train_opt = self.opt['train'] diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py index 54c80bd6d..784a3cc31 100644 --- a/basicsr/models/sr_model.py +++ b/basicsr/models/sr_model.py @@ -111,7 +111,7 @@ def optimize_parameters(self, current_iter): loss_dict['l_style'] = l_style l_total.backward() - self.optimizer_g.step() + self.optimizer_step(self.optimizer_g) self.log_dict = self.reduce_loss_dict(loss_dict) diff --git a/basicsr/train.py b/basicsr/train.py index f63149c64..c024ca7c0 100644 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -12,7 +12,7 @@ from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str, init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir) from basicsr.utils.options import copy_opt_file, dict2str, parse_options - +from basicsr.utils.accelerator_util import accelerator_name def init_tb_loggers(opt): # initialize wandb logger before tensorboard logger to allow proper sync @@ -139,6 +139,8 @@ def train_pipeline(root_path): if prefetch_mode is None or prefetch_mode == 'cpu': prefetcher = CPUPrefetcher(train_loader) elif prefetch_mode == 'cuda': + if accelerator_name(opt) != 'cuda': + raise ValueError(f"prefetch_mode cuda is not compatible with accelerator {accelerator_name(opt)}.") prefetcher = CUDAPrefetcher(train_loader, opt) logger.info(f'Use {prefetch_mode} prefetch dataloader') if opt['datasets']['train'].get('pin_memory') is not True: diff --git a/basicsr/utils/accelerator_util.py b/basicsr/utils/accelerator_util.py new file mode 100644 index 000000000..644b9c3fb --- /dev/null +++ b/basicsr/utils/accelerator_util.py @@ -0,0 +1,25 @@ +import torch +try: + import torch_xla.core.xla_model as xm +except: + 
pass + + # self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + +def accelerator_name(opt): + if opt['num_gpu'] == 0: + return 'cpu' + return opt.get('accelerator', 'cuda') + +def default_device(opt): + accelerator = accelerator_name(opt) + if accelerator == 'xla': + return xm.xla_device() + return accelerator + +def device_count(opt): + accelerator = opt.get('accelerator', 'cuda') + if accelerator == 'xla': + device = xm.xla_device() + return 0 if device is None else 1 + return torch.cuda.device_count() \ No newline at end of file diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index 09bfa5a5b..e761fc5f9 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -6,9 +6,9 @@ from os import path as osp from basicsr.utils import set_random_seed +from basicsr.utils.accelerator_util import device_count from basicsr.utils.dist_util import get_dist_info, init_dist, master_only - def ordered_yaml(): """Support OrderedDict for yaml. @@ -135,7 +135,7 @@ def parse_options(root_path, is_train=True): opt['name'] = 'debug_' + opt['name'] if opt['num_gpu'] == 'auto': - opt['num_gpu'] = torch.cuda.device_count() + opt['num_gpu'] = device_count(opt) # datasets for phase, dataset in opt['datasets'].items(): From b59a97b8c2eaf81fe094b4d1c5f420b121bce596 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 23 Feb 2022 02:03:54 +0000 Subject: [PATCH 2/3] XLA MP support. This is still hackish and being tested. --- basicsr/models/base_model.py | 9 +++-- basicsr/train.py | 58 ++++++++++++++++++++++++++----- basicsr/utils/accelerator_util.py | 22 ++++++++---- basicsr/utils/dist_util.py | 19 ++++++++++ basicsr/utils/options.py | 27 +++++++++++--- 5 files changed, 115 insertions(+), 20 deletions(-) diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index 9fc40a7a5..c3c852485 100644 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -94,13 +94,18 @@ def get_current_log(self): return self.log_dict def model_to_device(self, net): - """Model to device. It also warps models with DistributedDataParallel + """Model to device. It also wraps models with DistributedDataParallel or DataParallel. Args: net (nn.Module) """ net = net.to(self.device) + + if self.accelerator == 'xla': + # No need to use DataParallel or DistributedDataParallel with xmp + return net + if self.opt['dist']: find_unused_parameters = self.opt.get('find_unused_parameters', False) net = DistributedDataParallel( @@ -377,7 +382,7 @@ def reduce_loss_dict(self, loss_dict): loss_dict (OrderedDict): Loss dict. 
""" with torch.no_grad(): - if self.opt['dist']: + if self.opt['dist'] and self.accelerator != 'xla': keys = [] losses = [] for name, value in loss_dict.items(): diff --git a/basicsr/train.py b/basicsr/train.py index c024ca7c0..1e4649ff5 100644 --- a/basicsr/train.py +++ b/basicsr/train.py @@ -11,8 +11,14 @@ from basicsr.models import build_model from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str, init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir) -from basicsr.utils.options import copy_opt_file, dict2str, parse_options -from basicsr.utils.accelerator_util import accelerator_name +from basicsr.utils.options import copy_opt_file, dict2str, parse_options, preflight_options +from basicsr.utils import accelerator_util +from basicsr.utils import dist_util + +try: + import torch_xla.core.xla_model as xm +except: + pass def init_tb_loggers(opt): # initialize wandb logger before tensorboard logger to allow proper sync @@ -88,13 +94,22 @@ def load_resume_state(opt): return resume_state -def train_pipeline(root_path): +def print_xla_device_info(opt): + if accelerator_util.accelerator_name(opt) == 'xla': + import torch_xla.core.xla_model as xm + print(f"XLA device: {xm.xla_device()} [{xm.xla_real_devices([xm.xla_device()])[0]}], ordinal: {xm.get_ordinal()}, replicas: {xm.xrt_world_size()}, is master: {xm.is_master_ordinal()}") + + # These are expected to be the same + rank, world_size = dist_util.get_dist_info() + assert(rank == xm.get_ordinal()) + assert(world_size == xm.xrt_world_size()) + +def _mp_train_pipeline(root_path): # parse options, set distributed setting, set ramdom seed opt, args = parse_options(root_path, is_train=True) opt['root_path'] = root_path - torch.backends.cudnn.benchmark = True - # torch.backends.cudnn.deterministic = True + print_xla_device_info(opt) # load resume states if necessary resume_state = load_resume_state(opt) @@ -139,8 +154,8 @@ def train_pipeline(root_path): if prefetch_mode is None or prefetch_mode == 'cpu': prefetcher = CPUPrefetcher(train_loader) elif prefetch_mode == 'cuda': - if accelerator_name(opt) != 'cuda': - raise ValueError(f"prefetch_mode cuda is not compatible with accelerator {accelerator_name(opt)}.") + if accelerator_util.accelerator_name(opt) != 'cuda': + raise ValueError(f"prefetch_mode cuda is not compatible with accelerator {accelerator_util.accelerator_name(opt)}.") prefetcher = CUDAPrefetcher(train_loader, opt) logger.info(f'Use {prefetch_mode} prefetch dataloader') if opt['datasets']['train'].get('pin_memory') is not True: @@ -198,7 +213,6 @@ def train_pipeline(root_path): iter_timer.start() train_data = prefetcher.next() # end of iter - # end of epoch consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) @@ -212,6 +226,34 @@ def train_pipeline(root_path): tb_logger.close() +def _mp_train(rank, root_path): + _mp_train_pipeline(root_path) + +def train_pipeline(root_path): + """ + Determines whether multiple processes need to be spawned, and then invoke _mp_train_pipeline. + This mode is meant to be used with XLA multiprocessing (but it should work in other environments). + However, it is incompatible with command-line multi-process launchers. + This is also an appropriate entry point to be invoked from Real-ESRGAN for XLA MP support. 
+ """ + # Initial parse to determine whether we need to run under xmp multiprocessing + opt, args = preflight_options() + + if opt.get('accelerator', 'cuda') == 'xla': + # We can't get the number of XLA devices because it would cause a replication error. + # We just assume xla multiprocessing is required, except if a launcher was used. + if args.launcher != "none": + raise ValueError(f"Launcher {args.launcher} is incompatible with XLA multiprocessing.") + + import torch_xla.distributed.xla_multiprocessing as xmp + xmp.spawn(_mp_train, args=(root_path,), start_method='fork') + else: + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + _mp_train_pipeline(root_path) + + if __name__ == '__main__': root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) train_pipeline(root_path) diff --git a/basicsr/utils/accelerator_util.py b/basicsr/utils/accelerator_util.py index 644b9c3fb..9290d2c90 100644 --- a/basicsr/utils/accelerator_util.py +++ b/basicsr/utils/accelerator_util.py @@ -4,8 +4,6 @@ except: pass - # self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') - def accelerator_name(opt): if opt['num_gpu'] == 0: return 'cpu' @@ -18,8 +16,20 @@ def default_device(opt): return accelerator def device_count(opt): - accelerator = opt.get('accelerator', 'cuda') + accelerator = accelerator_name(opt) if accelerator == 'xla': - device = xm.xla_device() - return 0 if device is None else 1 - return torch.cuda.device_count() \ No newline at end of file + # Devices of the same hw family. + # Note: returns 1 when replication is in place! + # device = xm.xla_device() + # devices = xm.get_xla_supported_devices(xm.xla_device_hw(device)) + # return len(devices) + + # This works when replication is active + return xm.xrt_world_size() + return torch.cuda.device_count() + +def use_xmp(opt): + accelerator = accelerator_name(opt) + if accelerator != 'xla': + return False + return device_count(opt) > 1 \ No newline at end of file diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py index 0fab887b2..684749238 100644 --- a/basicsr/utils/dist_util.py +++ b/basicsr/utils/dist_util.py @@ -6,6 +6,16 @@ import torch.distributed as dist import torch.multiprocessing as mp +_xmp_dist = None + +class XMPDist(): + def __init__(self): + super().__init__() + + import torch_xla.core.xla_model as xm + self.rank = xm.get_ordinal() + self.world_size = xm.xrt_world_size() + self.is_master = xm.is_master_ordinal() def init_dist(launcher, backend='nccl', **kwargs): if mp.get_start_method(allow_none=True) is None: @@ -17,6 +27,9 @@ def init_dist(launcher, backend='nccl', **kwargs): else: raise ValueError(f'Invalid launcher type: {launcher}') +def init_xmp(): + global _xmp_dist + _xmp_dist = XMPDist() def _init_dist_pytorch(backend, **kwargs): rank = int(os.environ['RANK']) @@ -58,6 +71,9 @@ def _init_dist_slurm(backend, port=None): def get_dist_info(): + if _xmp_dist is not None: + return _xmp_dist.rank, _xmp_dist.world_size + if dist.is_available(): initialized = dist.is_initialized() else: @@ -75,6 +91,9 @@ def master_only(func): @functools.wraps(func) def wrapper(*args, **kwargs): + if _xmp_dist is not None and _xmp_dist.is_master: + return func(*args, **kwargs) + rank, _ = get_dist_info() if rank == 0: return func(*args, **kwargs) diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index e761fc5f9..1ce5ea908 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -6,8 +6,8 @@ from os import path as osp from basicsr.utils import 
set_random_seed -from basicsr.utils.accelerator_util import device_count -from basicsr.utils.dist_util import get_dist_info, init_dist, master_only +from basicsr.utils.accelerator_util import use_xmp, device_count +from basicsr.utils.dist_util import get_dist_info, init_dist, init_xmp, master_only def ordered_yaml(): """Support OrderedDict for yaml. @@ -79,7 +79,17 @@ def _postprocess_yml_value(value): return value -def parse_options(root_path, is_train=True): +def preflight_options(): + """ + Just parse all the options for initial verification. + + Attempting to access xla devices (such as trying to determine the number of + available devices, for instance) results in xmp not being able to create the + replication devices. + + We use this function so the main training script can determine whether + we need to use XLA MP replication. + """ parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') @@ -94,8 +104,17 @@ def parse_options(root_path, is_train=True): with open(args.opt, mode='r') as f: opt = yaml.load(f, Loader=ordered_yaml()[0]) + return opt, args + +def parse_options(root_path, is_train=True): + opt, args = preflight_options() + # distributed settings - if args.launcher == 'none': + if use_xmp(opt): + # xmp parallelism is handled specially, we don't need opt['dist'] + opt['dist'] = True + init_xmp() + elif args.launcher == 'none': opt['dist'] = False print('Disable distributed.', flush=True) else: From 36dca17f2dc637c4b932c4c7e785092a2901ef8d Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 23 Feb 2022 17:29:11 +0000 Subject: [PATCH 3/3] Remove confusing comment, reword another. --- basicsr/utils/options.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py index 1ce5ea908..6f0ef1823 100644 --- a/basicsr/utils/options.py +++ b/basicsr/utils/options.py @@ -88,7 +88,7 @@ def preflight_options(): replication devices. We use this function so the main training script can determine whether - we need to use XLA MP replication. + we need to use XLA MP replication before proceeding further. """ parser = argparse.ArgumentParser() parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') @@ -111,7 +111,6 @@ def parse_options(root_path, is_train=True): # distributed settings if use_xmp(opt): - # xmp parallelism is handled specially, we don't need opt['dist'] opt['dist'] = True init_xmp() elif args.launcher == 'none':
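
For reference, a minimal usage sketch of the helpers these patches introduce in basicsr/utils/accelerator_util.py, assuming torch_xla is installed and the options YAML sets accelerator: xla and num_gpu: auto (the literal opt dict below is illustrative only; in the patches it comes from parse_options()/preflight_options()):

    # Sketch: how the new accelerator_util helpers resolve devices from the
    # parsed options dict. The inline `opt` stands in for the YAML options.
    from basicsr.utils.accelerator_util import accelerator_name, default_device, device_count

    opt = {'num_gpu': 'auto', 'accelerator': 'xla'}  # 'accelerator' defaults to 'cuda' when absent

    # parse_options() replaces 'auto' with device_count(opt):
    # xm.xrt_world_size() under XLA, torch.cuda.device_count() otherwise.
    opt['num_gpu'] = device_count(opt)

    print(accelerator_name(opt))  # 'xla' ('cpu' whenever num_gpu == 0)
    print(default_device(opt))    # xm.xla_device() under XLA, else the accelerator name string

With these options, train_pipeline() from patch 2 takes the xmp.spawn(_mp_train, ..., start_method='fork') path; combining accelerator: xla with --launcher pytorch or --launcher slurm raises a ValueError, since command-line launchers are incompatible with XLA multiprocessing.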