From b92a266065f65a9bb56b6173493246df5f8a57b4 Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Wed, 16 Oct 2024 08:03:29 +0000
Subject: [PATCH 1/3] add npu support

---
 src/open_clip_train/distributed.py | 6 ++++++
 src/open_clip_train/main.py        | 6 +++++-
 src/open_clip_train/params.py      | 2 +-
 src/open_clip_train/profiler.py    | 2 ++
 4 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/open_clip_train/distributed.py b/src/open_clip_train/distributed.py
index 268a6c7ad..dd18cb643 100644
--- a/src/open_clip_train/distributed.py
+++ b/src/open_clip_train/distributed.py
@@ -107,6 +107,12 @@ def init_distributed_device(args):
         else:
             device = 'cuda:0'
         torch.cuda.set_device(device)
+    elif torch.npu.is_available():
+        if args.distributed and not args.no_set_device_rank:
+            device = 'npu:%d' % args.local_rank
+        else:
+            device = "npu:0"
+        torch.npu.set_device(device)
     else:
         device = 'cpu'
     args.device = device
diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py
index 591ea1d32..ff1cf2003 100644
--- a/src/open_clip_train/main.py
+++ b/src/open_clip_train/main.py
@@ -11,7 +11,6 @@
 import numpy as np
 import torch
 from torch import optim
-from torch.cuda.amp import GradScaler
 
 try:
     import wandb
@@ -329,6 +328,11 @@ def main(args):
             hvd.broadcast_parameters(model.state_dict(), root_rank=0)
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
+        if args.precision == "amp":
+            if torch.npu.is_available():
+                from torch.npu.amp import GradScaler
+            else:
+                from torch.cuda.amp import GradScaler
         scaler = GradScaler() if args.precision == "amp" else None
 
     # optionally resume from a checkpoint
diff --git a/src/open_clip_train/params.py b/src/open_clip_train/params.py
index c3d19302d..829b63817 100644
--- a/src/open_clip_train/params.py
+++ b/src/open_clip_train/params.py
@@ -314,7 +314,7 @@ def parse_args(args):
         help="url used to set up distributed training",
     )
     parser.add_argument(
-        "--dist-backend", default="nccl", type=str, help="distributed backend"
+        "--dist-backend", default="nccl", type=str, help="distributed backend. \"nccl\" for GPU, \"hccl\" for Ascend NPU"
    )
     parser.add_argument(
         "--report-to",
diff --git a/src/open_clip_train/profiler.py b/src/open_clip_train/profiler.py
index 4e10a4c26..cd2a588d5 100644
--- a/src/open_clip_train/profiler.py
+++ b/src/open_clip_train/profiler.py
@@ -133,6 +133,8 @@ def profile_model(model_name, batch_size=1, profiler='torch'):
     model.eval()
     if torch.cuda.is_available():
         model = model.cuda()
+    elif torch.npu.is_available():
+        model = model.npu()
 
     if isinstance(model.visual.image_size, (tuple, list)):
         image_input_size = (3,) + tuple(model.visual.image_size[-2:])

From 0dee0b45c736bdd1e4e333ef87e59ca8d1b26799 Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Thu, 17 Oct 2024 12:53:02 +0000
Subject: [PATCH 2/3] fix device check

---
 src/open_clip_train/distributed.py | 2 +-
 src/open_clip_train/main.py        | 2 +-
 src/open_clip_train/params.py      | 3 +++
 src/open_clip_train/profiler.py    | 6 +++---
 4 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/open_clip_train/distributed.py b/src/open_clip_train/distributed.py
index dd18cb643..1438c9f5f 100644
--- a/src/open_clip_train/distributed.py
+++ b/src/open_clip_train/distributed.py
@@ -107,7 +107,7 @@ def init_distributed_device(args):
         else:
             device = 'cuda:0'
         torch.cuda.set_device(device)
-    elif torch.npu.is_available():
+    elif args.device == "npu" and torch.npu.is_available():
         if args.distributed and not args.no_set_device_rank:
             device = 'npu:%d' % args.local_rank
         else:
diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py
index ff1cf2003..1089ade53 100644
--- a/src/open_clip_train/main.py
+++ b/src/open_clip_train/main.py
@@ -329,7 +329,7 @@ def main(args):
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
         if args.precision == "amp":
-            if torch.npu.is_available():
+            if args.device == "npu" and torch.npu.is_available():
                 from torch.npu.amp import GradScaler
             else:
                 from torch.cuda.amp import GradScaler
diff --git a/src/open_clip_train/params.py b/src/open_clip_train/params.py
index 829b63817..c5b7804f2 100644
--- a/src/open_clip_train/params.py
+++ b/src/open_clip_train/params.py
@@ -306,6 +306,9 @@ def parse_args(args):
     parser.add_argument(
         "--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps."
     )
+    parser.add_argument(
+        "--device", default="cuda", type=str, choices=["cpu", "cuda", "npu"], help="Accelerator to use."
+    )
     # arguments for distributed training
     parser.add_argument(
         "--dist-url",
diff --git a/src/open_clip_train/profiler.py b/src/open_clip_train/profiler.py
index cd2a588d5..17c302201 100644
--- a/src/open_clip_train/profiler.py
+++ b/src/open_clip_train/profiler.py
@@ -125,7 +125,7 @@ def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_
 def count_params(model):
     return sum(m.numel() for m in model.parameters())
 
-def profile_model(model_name, batch_size=1, profiler='torch'):
+def profile_model(model_name, batch_size=1, profiler='torch', device="cuda"):
     assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported'
     if profiler == 'fvcore':
         assert fvcore is not None, 'Please install fvcore.'
@@ -133,7 +133,7 @@ def profile_model(model_name, batch_size=1, profiler='torch'):
     model.eval()
     if torch.cuda.is_available():
         model = model.cuda()
-    elif torch.npu.is_available():
+    elif device == "npu" and torch.npu.is_available():
         model = model.npu()
 
     if isinstance(model.visual.image_size, (tuple, list)):
@@ -219,7 +219,7 @@ def main():
         print('='*100)
         print(f'Profiling {m}')
         try:
-            row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler)
+            row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler, device=args.device)
             results.append(row)
         except Exception as e:
             print(f'Error profiling {m}: {e}')

From acc40bb023e9da18404c8c7d7cdcec025c37a4fb Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Sat, 19 Oct 2024 13:38:58 +0000
Subject: [PATCH 3/3] update scaler

---
 src/open_clip_train/main.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py
index 1089ade53..1aa0750fc 100644
--- a/src/open_clip_train/main.py
+++ b/src/open_clip_train/main.py
@@ -328,12 +328,12 @@ def main(args):
             hvd.broadcast_parameters(model.state_dict(), root_rank=0)
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
+        scaler = None
         if args.precision == "amp":
-            if args.device == "npu" and torch.npu.is_available():
-                from torch.npu.amp import GradScaler
-            else:
-                from torch.cuda.amp import GradScaler
-        scaler = GradScaler() if args.precision == "amp" else None
+            try:
+                scaler = torch.amp.GradScaler(device=device)
+            except (AttributeError, TypeError) as e:
+                scaler = torch.cuda.amp.GradScaler()
 
     # optionally resume from a checkpoint
     start_epoch = 0
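
Summary note (commentary, not part of the patches above): the end state of this series selects the accelerator via the new --device flag (on Ascend, --device npu together with --dist-backend hccl) and builds the AMP scaler without any NPU-specific GradScaler import. Below is a minimal, self-contained sketch of that final scaler logic; the helper name make_grad_scaler is illustrative only and does not exist in the patched main.py, which inlines the same try/except.

    import torch

    def make_grad_scaler(device: str, precision: str = "amp"):
        # Illustrative restatement of the scaler setup after PATCH 3/3.
        if precision != "amp":
            return None
        try:
            # Unified scaler; newer PyTorch (>= 2.3) accepts a device argument.
            return torch.amp.GradScaler(device=device)
        except (AttributeError, TypeError):
            # Older PyTorch builds: fall back to the CUDA-specific scaler.
            return torch.cuda.amp.GradScaler()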