Skip to content

Commit

Permalink
add npu support
Browse files: browse the repository at this point in the history
  • Loading branch information
MengqingCao committed Oct 16, 2024
1 parent 6de2025 commit b92a266
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 2 deletions.
6 changes: 6 additions & 0 deletions src/open_clip_train/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def init_distributed_device(args):
else:
device = 'cuda:0'
torch.cuda.set_device(device)
elif torch.npu.is_available():
if args.distributed and not args.no_set_device_rank:
device = 'npu:%d' % args.local_rank
else:
device = "npu:0"
torch.npu.set_device(device)
else:
device = 'cpu'
args.device = device
Expand Down
6 changes: 5 additions & 1 deletion src/open_clip_train/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np
import torch
from torch import optim
from torch.cuda.amp import GradScaler

try:
import wandb
Expand Down Expand Up @@ -329,6 +328,11 @@ def main(args):
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

if args.precision == "amp":
if torch.npu.is_available():
from torch.npu.amp import GradScaler
else:
from torch.cuda.amp import GradScaler
scaler = GradScaler() if args.precision == "amp" else None

# optionally resume from a checkpoint
Expand Down
2 changes: 1 addition & 1 deletion src/open_clip_train/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def parse_args(args):
help="url used to set up distributed training",
)
parser.add_argument(
"--dist-backend", default="nccl", type=str, help="distributed backend"
"--dist-backend", default="nccl", type=str, help="distributed backend. \"nccl\" for GPU, \"hccl\" for Ascend NPU"
)
parser.add_argument(
"--report-to",
Expand Down
2 changes: 2 additions & 0 deletions src/open_clip_train/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def profile_model(model_name, batch_size=1, profiler='torch'):
model.eval()
if torch.cuda.is_available():
model = model.cuda()
elif torch.npu.is_available():
model = model.npu()

if isinstance(model.visual.image_size, (tuple, list)):
image_input_size = (3,) + tuple(model.visual.image_size[-2:])
Expand Down

0 comments on commit b92a266

Please sign in to comment.