Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use MPS and explicitly disable autocast & GradScaler for non-CUDA #654

Closed
wants to merge 2 commits into from

Conversation

EIFY
Copy link
Contributor

@EIFY EIFY commented Oct 4, 2023

device = "mps" makes both training and inference ~ an order of magnitude faster on newer Macs, torch==2.2.0.dev20231002:

Training: MPS vs. CPU, RN50 model (username redacted)

python3 -m training.main \
    --report-to wandb \
    --name MPS-batch-512 \
    --save-frequency 1 \
    --train-data="/Users/█/Downloads/redcaps_v1.0_annotations/tarfiles/redcaps_combined_{00000000..00000585}.tar" \
    --train-num-samples 585668 \
    --dataset-type webdataset \
    --warmup 1000 \
    --batch-size=512 \
    --lr=5e-4 \
    --wd=0.1 \
    --epochs=5 \
    --workers=0 \
    --model RN50
Screenshot 2023-10-04 at 1 35 32 PM

Inference:


time python3 -m training.main \
    --imagenet-val="/Users/█/Downloads/ImageNetV2-val/" \
    --model RN50 \
    --pretrained logs/MPS-batch-512/checkpoints/epoch_5.pt \

mps:

2023-10-04,13:00:17 | INFO | Starting zero-shot imagenet.
2023-10-04,13:00:17 | INFO | Building zero-shot classifier
2023-10-04,13:00:17 | WARNING | MPS devices do not support AMP yet: https://github.com/pytorch/pytorch/issues/88415 Disabling AMP and falling back to FP32.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:19<00:00,  5.03it/s]
2023-10-04,13:03:36 | INFO | Using classifier
2023-10-04,13:03:36 | WARNING | MPS devices do not support AMP yet: https://github.com/pytorch/pytorch/issues/88415 Disabling AMP and falling back to FP32.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50048/50048 [02:46<00:00, 300.65it/s]
2023-10-04,13:06:23 | INFO | Finished zero-shot imagenet.
2023-10-04,13:06:23 | WARNING | MPS devices do not support AMP yet: https://github.com/pytorch/pytorch/issues/88415 Disabling AMP and falling back to FP32.
2023-10-04,13:06:23 | INFO | Eval Epoch: 0 imagenet-zeroshot-val-top1: 0.0282   imagenet-zeroshot-val-top5: 0.0905
[2023-10-04 13:06:23,130] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2023-10-04 13:06:23,130] torch._dynamo.utils: [INFO] Function, Runtimes (s)

real    6m12.258s
user    4m51.843s
sys 0m49.552s

cpu:

2023-10-04,11:28:18 | INFO | Starting zero-shot imagenet.
2023-10-04,11:28:18 | INFO | Building zero-shot classifier
2023-10-04,11:28:18 | WARNING | CPU devices have limited AMP support and result in attn_mask.dtype: float and query.dtype: c10::BFloat16 mismatch. Disabling AMP and falling back to FP32.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [09:33<00:00,  1.75it/s]
2023-10-04,11:37:51 | INFO | Using classifier
2023-10-04,11:37:51 | WARNING | CPU devices have limited AMP support and result in attn_mask.dtype: float and query.dtype: c10::BFloat16 mismatch. Disabling AMP and falling back to FP32.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50048/50048 [1:06:30<00:00, 12.54it/s]
2023-10-04,12:44:21 | INFO | Finished zero-shot imagenet.
2023-10-04,12:44:21 | WARNING | CPU devices have limited AMP support and result in attn_mask.dtype: float and query.dtype: c10::BFloat16 mismatch. Disabling AMP and falling back to FP32.
2023-10-04,12:44:21 | INFO | Eval Epoch: 0 imagenet-zeroshot-val-top1: 0.0282   imagenet-zeroshot-val-top5: 0.0905
[2023-10-04 12:44:21,412] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2023-10-04 12:44:21,412] torch._dynamo.utils: [INFO] Function, Runtimes (s)

real    76m8.598s
user    396m31.792s
sys 35m38.153s

This PR also includes explicit handling of autocast & GradScaler for non-CUDA devices. Currently open_clip is hardcoded to use the CUDA version (torch.cuda.amp.autocast and torch.cuda.amp.GradScaler respectively), which get disabled with warnings like

/Users/█/Library/Python/3.10/lib/python/site-packages/torch/amp/autocast_mode.py:250: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
  warnings.warn(
/Users/█/Library/Python/3.10/lib/python/site-packages/torch/cuda/amp/grad_scaler.py:124: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.
  warnings.warn(

With this PR we issue our own warnings and explain the rationales & consequences:

  1. AMP support for MPS devices is tracked with Enable AMP for MPS devices pytorch/pytorch#88415
  2. We do have torch.cpu.amp.autocast but training with it failed with the stated attn_mask & query dtype mismatch.

@EIFY
Copy link
Contributor Author

EIFY commented Oct 7, 2023

The issue with torch.cpu.amp.autocast is probably the same as pytorch/pytorch#107663

@EIFY
Copy link
Contributor Author

EIFY commented Oct 16, 2023

Hmm, I don't know who is best suited to review this PR or who else is interested in running open_clip on M1/M2 Macs for that matter 🤔 @gabrielilharco Could you take a look?

@gabrielilharco
Copy link
Collaborator

Sorry, I don't have access to that kind of hardware so can't test it myself @EIFY

@EIFY
Copy link
Contributor Author

EIFY commented Oct 17, 2023

@gabrielilharco Do you know if any of the owners do? If not, can I get an external M1/M2 Mac user to endorse instead?

@rwightman
Copy link
Collaborator

So, supporting mps and other non cuda/cpu devices worthwhile goal, not sure 'this' is the best approach though.

For autocast, should we rely on the amp (precision) arg to determine whether or not to try to use autocast? If autocast is used with mps it should crash instead of falling back (in my opinion), so that it's more clear it doesn't work.

For the initialization of device, probably better to explicitly pass a device str to the fn that will be sensibly merged with the distributed env. For mps distributed doesn't make sense, but I wouldn't say we want to default to mps if mps is available? it's a tossup on m1 if you want to use it vs the CPU, we should likely err towards being explicit rather than implict here...

@EIFY
Copy link
Contributor Author

EIFY commented Oct 21, 2023

If autocast is used with mps it should crash instead of falling back (in my opinion), so that it's more clear it doesn't work.

@rwightman Falling back is the current behavior for both autocast & grad_scaler, see these two warnings:

/Users/█/Library/Python/3.10/lib/python/site-packages/torch/amp/autocast_mode.py:250: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
  warnings.warn(
/Users/█/Library/Python/3.10/lib/python/site-packages/torch/cuda/amp/grad_scaler.py:124: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.
  warnings.warn(

but I actually agree: I would rather the training code to crash when --precision=amp on non-cuda.

For the initialization of device, probably better to explicitly pass a device str to the fn that will be sensibly merged with the distributed env. For mps distributed doesn't make sense, but I wouldn't say we want to default to mps if mps is available? it's a tossup on m1 if you want to use it vs the CPU, we should likely err towards being explicit rather than implict here...

Similarly for the current handling of cuda vs. cpu: Right now it's falling back to cpu if cuda isn't available.
So...maybe make a clean break and consistently handle by crashing? i.e.

  1. If precision is amp* and the device is non-cuda, crash. Leave autocast & grad_scaler as they are.
  2. Add a --device param that defaults to cuda, and whether it's cuda or mps, we always crash if the specified device isn't available.

@EIFY EIFY mentioned this pull request Oct 22, 2023
@rwightman
Copy link
Collaborator

@EIFY been a while but torch has finally improved some things related to this, #965 should enable mps and other 3rd party devices to work, incl w/ AMP autocast & grad scaler w/o jumping through as many hoops

@rwightman rwightman closed this Oct 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants