XLA support #512

Draft · wants to merge 3 commits into master
27 changes: 24 additions & 3 deletions basicsr/models/base_model.py
@@ -7,19 +7,28 @@

from basicsr.models import lr_scheduler as lr_scheduler
from basicsr.utils import get_root_logger
from basicsr.utils import accelerator_util
from basicsr.utils.dist_util import master_only

try:
import torch_xla.core.xla_model as xm
except ImportError:
    pass  # torch_xla is optional; only required when the XLA accelerator is selected

class BaseModel():
"""Base model."""

def __init__(self, opt):
self.opt = opt
-        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.device = accelerator_util.default_device(opt)
self.is_train = opt['is_train']
self.schedulers = []
self.optimizers = []

@property
def accelerator(self):
return accelerator_util.accelerator_name(self.opt)

def feed_data(self, data):
pass

@@ -85,13 +94,18 @@ def get_current_log(self):
return self.log_dict

def model_to_device(self, net):
"""Model to device. It also warps models with DistributedDataParallel
"""Model to device. It also wraps models with DistributedDataParallel
or DataParallel.

Args:
net (nn.Module)
"""
net = net.to(self.device)

if self.accelerator == 'xla':
# No need to use DataParallel or DistributedDataParallel with xmp
return net

if self.opt['dist']:
find_unused_parameters = self.opt.get('find_unused_parameters', False)
net = DistributedDataParallel(
@@ -107,6 +121,13 @@ def get_optimizer(self, optim_type, params, lr, **kwargs):
            raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
return optimizer

def optimizer_step(self, optimizer):
if self.accelerator == 'xla':
xm.optimizer_step(optimizer)
xm.mark_step()
else:
optimizer.step()

def setup_schedulers(self):
"""Set up schedulers."""
train_opt = self.opt['train']
@@ -361,7 +382,7 @@ def reduce_loss_dict(self, loss_dict):
loss_dict (OrderedDict): Loss dict.
"""
with torch.no_grad():
-            if self.opt['dist']:
if self.opt['dist'] and self.accelerator != 'xla':
keys = []
losses = []
for name, value in loss_dict.items():
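Note on the XLA optimizer step above: xm.optimizer_step performs the cross-replica gradient reduction before calling the wrapped optimizer's step (which is why the model is not wrapped in DistributedDataParallel on XLA), and xm.mark_step cuts the lazily recorded graph and triggers execution. Below is a minimal, self-contained sketch of that pattern; it is not part of this PR, the toy model, data, and loop are illustrative, and it assumes torch_xla is installed.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):
    x = torch.randn(8, 16, device=device)
    optimizer.zero_grad()
    loss = model(x).pow(2).mean()
    loss.backward()
    # Reduces gradients across replicas (when running under xmp) and then
    # calls optimizer.step(); no DDP/DataParallel wrapper is needed.
    xm.optimizer_step(optimizer)
    # Cuts the lazily recorded graph and triggers execution on the XLA device.
    xm.mark_step()
```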
2 changes: 1 addition & 1 deletion basicsr/models/sr_model.py
@@ -111,7 +111,7 @@ def optimize_parameters(self, current_iter):
loss_dict['l_style'] = l_style

l_total.backward()
-        self.optimizer_g.step()
self.optimizer_step(self.optimizer_g)

self.log_dict = self.reduce_loss_dict(loss_dict)

54 changes: 49 additions & 5 deletions basicsr/train.py
@@ -11,8 +11,14 @@
from basicsr.models import build_model
from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
-from basicsr.utils.options import copy_opt_file, dict2str, parse_options
from basicsr.utils.options import copy_opt_file, dict2str, parse_options, preflight_options
from basicsr.utils import accelerator_util
from basicsr.utils import dist_util

try:
import torch_xla.core.xla_model as xm
except ImportError:
    pass  # torch_xla is optional; only required when the XLA accelerator is selected

def init_tb_loggers(opt):
# initialize wandb logger before tensorboard logger to allow proper sync
@@ -88,13 +94,22 @@ def load_resume_state(opt):
return resume_state


-def train_pipeline(root_path):
def print_xla_device_info(opt):
if accelerator_util.accelerator_name(opt) == 'xla':
import torch_xla.core.xla_model as xm
print(f"XLA device: {xm.xla_device()} [{xm.xla_real_devices([xm.xla_device()])[0]}], ordinal: {xm.get_ordinal()}, replicas: {xm.xrt_world_size()}, is master: {xm.is_master_ordinal()}")

        # The rank/world size reported by dist_util and by XLA are expected to agree.
        rank, world_size = dist_util.get_dist_info()
        assert rank == xm.get_ordinal()
        assert world_size == xm.xrt_world_size()

def _mp_train_pipeline(root_path):
    # parse options, set distributed setting, set random seed
opt, args = parse_options(root_path, is_train=True)
opt['root_path'] = root_path

-    torch.backends.cudnn.benchmark = True
-    # torch.backends.cudnn.deterministic = True
print_xla_device_info(opt)

# load resume states if necessary
resume_state = load_resume_state(opt)
@@ -139,6 +154,8 @@ def train_pipeline(root_path):
if prefetch_mode is None or prefetch_mode == 'cpu':
prefetcher = CPUPrefetcher(train_loader)
elif prefetch_mode == 'cuda':
if accelerator_util.accelerator_name(opt) != 'cuda':
raise ValueError(f"prefetch_mode cuda is not compatible with accelerator {accelerator_util.accelerator_name(opt)}.")
prefetcher = CUDAPrefetcher(train_loader, opt)
logger.info(f'Use {prefetch_mode} prefetch dataloader')
if opt['datasets']['train'].get('pin_memory') is not True:
@@ -196,7 +213,6 @@
iter_timer.start()
train_data = prefetcher.next()
# end of iter

# end of epoch

consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
@@ -210,6 +226,34 @@
tb_logger.close()


def _mp_train(rank, root_path):
    # 'rank' is supplied by xmp.spawn but is unused here; each spawned process re-parses the options.
    _mp_train_pipeline(root_path)

def train_pipeline(root_path):
"""
    Determines whether multiple processes need to be spawned, then invokes _mp_train_pipeline.

    This mode is meant for XLA multiprocessing (but it should also work in other environments).
    However, it is incompatible with command-line multi-process launchers.
    This is also the appropriate entry point to invoke from Real-ESRGAN for XLA MP support.
"""
# Initial parse to determine whether we need to run under xmp multiprocessing
opt, args = preflight_options()

if opt.get('accelerator', 'cuda') == 'xla':
        # We can't query the number of XLA devices here: doing so would cause a replication error.
        # We therefore assume XLA multiprocessing is required, unless a launcher was used.
if args.launcher != "none":
raise ValueError(f"Launcher {args.launcher} is incompatible with XLA multiprocessing.")

import torch_xla.distributed.xla_multiprocessing as xmp
xmp.spawn(_mp_train, args=(root_path,), start_method='fork')
else:
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True

_mp_train_pipeline(root_path)


if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path)
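As the docstring above notes, train_pipeline is also intended as an entry point for other projects such as Real-ESRGAN. A hedged sketch of such a wrapper script follows; the import of train_pipeline is real, while the surrounding script is illustrative.

```python
import os.path as osp

from basicsr.train import train_pipeline

if __name__ == '__main__':
    # train_pipeline itself decides whether to spawn XLA processes
    # (accelerator: xla) or to run directly, so the wrapper stays accelerator-agnostic.
    root_path = osp.abspath(osp.join(__file__, osp.pardir))
    train_pipeline(root_path)
```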
35 changes: 35 additions & 0 deletions basicsr/utils/accelerator_util.py
@@ -0,0 +1,35 @@
import torch
try:
import torch_xla.core.xla_model as xm
except ImportError:
    pass  # torch_xla is optional; only required when the XLA accelerator is selected

def accelerator_name(opt):
if opt['num_gpu'] == 0:
return 'cpu'
return opt.get('accelerator', 'cuda')

def default_device(opt):
accelerator = accelerator_name(opt)
if accelerator == 'xla':
return xm.xla_device()
return accelerator

def device_count(opt):
accelerator = accelerator_name(opt)
if accelerator == 'xla':
        # Counting devices of the same hardware family does not work here: it
        # returns 1 once replication is in place.
        #   device = xm.xla_device()
        #   devices = xm.get_xla_supported_devices(xm.xla_device_hw(device))
        #   return len(devices)
        # xrt_world_size() works even when replication is active.
        return xm.xrt_world_size()
return torch.cuda.device_count()

def use_xmp(opt):
accelerator = accelerator_name(opt)
if accelerator != 'xla':
return False
return device_count(opt) > 1
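For reference, a hedged usage sketch of the new helper module (assuming it is importable as basicsr.utils.accelerator_util; the opt dicts are illustrative, and the XLA branches additionally require torch_xla):

```python
from basicsr.utils import accelerator_util

cpu_opt = {'num_gpu': 0}
cuda_opt = {'num_gpu': 2}
xla_opt = {'num_gpu': 8, 'accelerator': 'xla'}

print(accelerator_util.accelerator_name(cpu_opt))   # 'cpu'
print(accelerator_util.accelerator_name(cuda_opt))  # 'cuda' (the default)
print(accelerator_util.accelerator_name(xla_opt))   # 'xla'

# default_device returns the accelerator string for cpu/cuda (accepted by
# Module.to / Tensor.to) and xm.xla_device() when the accelerator is 'xla'.
print(accelerator_util.default_device(cuda_opt))    # 'cuda'

# device_count uses torch.cuda.device_count() for cuda and xm.xrt_world_size()
# for xla; use_xmp is True only for xla with more than one replica (needs torch_xla).
print(accelerator_util.device_count(cuda_opt))
print(accelerator_util.use_xmp(xla_opt))
```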
19 changes: 19 additions & 0 deletions basicsr/utils/dist_util.py
@@ -6,6 +6,16 @@
import torch.distributed as dist
import torch.multiprocessing as mp

_xmp_dist = None

class XMPDist:
    """Distributed-process info (rank, world size, master flag) when running under XLA multiprocessing."""

    def __init__(self):
        import torch_xla.core.xla_model as xm
        self.rank = xm.get_ordinal()
        self.world_size = xm.xrt_world_size()
        self.is_master = xm.is_master_ordinal()

def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None:
@@ -17,6 +27,9 @@ def init_dist(launcher, backend='nccl', **kwargs):
else:
raise ValueError(f'Invalid launcher type: {launcher}')

def init_xmp():
global _xmp_dist
_xmp_dist = XMPDist()

def _init_dist_pytorch(backend, **kwargs):
rank = int(os.environ['RANK'])
@@ -58,6 +71,9 @@ def _init_dist_slurm(backend, port=None):


def get_dist_info():
if _xmp_dist is not None:
return _xmp_dist.rank, _xmp_dist.world_size

if dist.is_available():
initialized = dist.is_initialized()
else:
@@ -75,6 +91,9 @@ def master_only(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
if _xmp_dist is not None and _xmp_dist.is_master:
return func(*args, **kwargs)

rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
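A minimal sketch of how the XMP-aware helpers above are meant to be used inside each process spawned by xmp.spawn; it assumes torch_xla and an XLA runtime are available, and the worker function is illustrative rather than part of this PR.

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

from basicsr.utils import dist_util


def _worker(rank):
    dist_util.init_xmp()               # records ordinal / world size / master flag via xm
    r, ws = dist_util.get_dist_info()  # now reflects the XLA replica layout
    assert r == xm.get_ordinal() and ws == xm.xrt_world_size()


if __name__ == '__main__':
    xmp.spawn(_worker, args=(), start_method='fork')
```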
28 changes: 23 additions & 5 deletions basicsr/utils/options.py
@@ -6,8 +6,8 @@
from os import path as osp

from basicsr.utils import set_random_seed
-from basicsr.utils.dist_util import get_dist_info, init_dist, master_only

from basicsr.utils.accelerator_util import use_xmp, device_count
from basicsr.utils.dist_util import get_dist_info, init_dist, init_xmp, master_only

def ordered_yaml():
"""Support OrderedDict for yaml.
@@ -79,7 +79,17 @@ def _postprocess_yml_value(value):
return value


-def parse_options(root_path, is_train=True):
def preflight_options():
"""
    Parse all the options for an initial check, without touching any devices.

    Attempting to access XLA devices here (for instance, to query how many are
    available) would later prevent xmp from creating the replication devices.

    The main training script uses this function to determine whether XLA MP
    replication is needed before proceeding further.
"""
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
@@ -94,8 +104,16 @@ def parse_options(root_path, is_train=True):
with open(args.opt, mode='r') as f:
opt = yaml.load(f, Loader=ordered_yaml()[0])

return opt, args

def parse_options(root_path, is_train=True):
opt, args = preflight_options()

# distributed settings
-    if args.launcher == 'none':
if use_xmp(opt):
opt['dist'] = True
init_xmp()
elif args.launcher == 'none':
opt['dist'] = False
print('Disable distributed.', flush=True)
else:
@@ -135,7 +153,7 @@ def parse_options(root_path, is_train=True):
opt['name'] = 'debug_' + opt['name']

if opt['num_gpu'] == 'auto':
-        opt['num_gpu'] = torch.cuda.device_count()
opt['num_gpu'] = device_count(opt)

# datasets
for phase, dataset in opt['datasets'].items():