From b92a266065f65a9bb56b6173493246df5f8a57b4 Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Wed, 16 Oct 2024 08:03:29 +0000
Subject: [PATCH] add npu support

---
 src/open_clip_train/distributed.py | 6 ++++++
 src/open_clip_train/main.py        | 6 +++++-
 src/open_clip_train/params.py      | 2 +-
 src/open_clip_train/profiler.py    | 2 ++
 4 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/open_clip_train/distributed.py b/src/open_clip_train/distributed.py
index 268a6c7ad..dd18cb643 100644
--- a/src/open_clip_train/distributed.py
+++ b/src/open_clip_train/distributed.py
@@ -107,6 +107,12 @@ def init_distributed_device(args):
         else:
             device = 'cuda:0'
         torch.cuda.set_device(device)
+    elif torch.npu.is_available():
+        if args.distributed and not args.no_set_device_rank:
+            device = 'npu:%d' % args.local_rank
+        else:
+            device = "npu:0"
+        torch.npu.set_device(device)
     else:
         device = 'cpu'
     args.device = device
diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py
index 591ea1d32..ff1cf2003 100644
--- a/src/open_clip_train/main.py
+++ b/src/open_clip_train/main.py
@@ -11,7 +11,6 @@
 import numpy as np
 import torch
 from torch import optim
-from torch.cuda.amp import GradScaler
 
 try:
     import wandb
@@ -329,6 +328,11 @@ def main(args):
             hvd.broadcast_parameters(model.state_dict(), root_rank=0)
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
+        if args.precision == "amp":
+            if torch.npu.is_available():
+                from torch.npu.amp import GradScaler
+            else:
+                from torch.cuda.amp import GradScaler
         scaler = GradScaler() if args.precision == "amp" else None
 
     # optionally resume from a checkpoint
diff --git a/src/open_clip_train/params.py b/src/open_clip_train/params.py
index c3d19302d..829b63817 100644
--- a/src/open_clip_train/params.py
+++ b/src/open_clip_train/params.py
@@ -314,7 +314,7 @@ def parse_args(args):
         help="url used to set up distributed training",
     )
     parser.add_argument(
-        "--dist-backend", default="nccl", type=str, help="distributed backend"
+        "--dist-backend", default="nccl", type=str, help="distributed backend. \"nccl\" for GPU, \"hccl\" for Ascend NPU"
    )
    parser.add_argument(
        "--report-to",
diff --git a/src/open_clip_train/profiler.py b/src/open_clip_train/profiler.py
index 4e10a4c26..cd2a588d5 100644
--- a/src/open_clip_train/profiler.py
+++ b/src/open_clip_train/profiler.py
@@ -133,6 +133,8 @@ def profile_model(model_name, batch_size=1, profiler='torch'):
     model.eval()
     if torch.cuda.is_available():
         model = model.cuda()
+    elif torch.npu.is_available():
+        model = model.npu()
 
     if isinstance(model.visual.image_size, (tuple, list)):
         image_input_size = (3,) + tuple(model.visual.image_size[-2:])
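
Note (reviewer sketch, not part of the patch): the new torch.npu branches above only work if Huawei's torch_npu plugin is installed, since importing it is what registers the torch.npu namespace. Below is a minimal, self-contained sketch of the device-selection flow the patch follows; the explicit torch_npu import and the hasattr guard are assumptions for illustration, not lines this diff adds to open_clip.

import torch

try:
    # Assumption: the Ascend plugin package "torch_npu" is installed and
    # registers the torch.npu namespace used by the patched code paths.
    import torch_npu  # noqa: F401
except ImportError:
    pass

if torch.cuda.is_available():
    device = 'cuda:0'
    torch.cuda.set_device(device)
elif hasattr(torch, 'npu') and torch.npu.is_available():
    device = 'npu:0'
    torch.npu.set_device(device)
else:
    device = 'cpu'

print(f'selected device: {device}')

When training distributed on Ascend hardware, the matching launch option would be --dist-backend hccl, per the params.py help-text change above; NCCL remains the default for CUDA GPUs.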