forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AmpKernels.cu
248 lines (221 loc) · 11.9 KB
/
AmpKernels.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
#define _USE_MATH_DEFINES
#include <math.h>
#include <ATen/ATen.h>
#include <ATen/DeviceGuard.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/TensorIterator.h>
namespace {
// Thin wrapper around https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE_1g57a3c8313f570282a1a7bcc78743b08e,
// to ensure the Cuda math library's isfinite is actually what gets called in
// _amp_non_finite_check_and_unscale_cuda_'s gpu_kernel lambda.
//
// isfinite_ensure_cuda_math is defined outside at::native because:
// - A bare call to "isfinite(val)" inside at::native causes nvcc to prefer the unrelated
// Tensor at::native::isfinite(const Tensor&), resulting in an error:
// "no suitable constructor exists to convert from "float" to "at::Tensor""
// - Unfortunately, the Cuda math library documentation doesn't say how (or if) you can provide a full namespace path
// to ensure that its version of a particular function is invoked. It only shows bare (not-namespaced)
// calls to its routines inside kernel or device functions.
// - "std::isfinite(val)" in the gpu_kernel lambda causes an "unspecified launch failure" at runtime with cuda 9 on Windows.
//
// isfinite_ensure_cuda_math, declared at file scope outside the at::native region, uses isfinite as math library docs
// suggest and allows disambiguated usage in the lambda within the at::native region.
// GPU_LAMBDA is defined as __host__ __device__ (see Loops.cuh), so I need the __host__ keyword or else nvcc complains that
// "calling a __device__ function("isfinite_ensure_cuda_math") from a __host__ __device__ function("operator()") is not allowed."
static __host__ __device__ __forceinline__ int isfinite_ensure_cuda_math(float val) {
return isfinite(val);
}
}
namespace at {
namespace native {
namespace {
// Single-tensor fallback for _amp_foreach_non_finite_check_and_unscale_cuda_.
// Handles individual tensors that are acceptable to unscale but not MTA-safe.
void _amp_non_finite_check_and_unscale_cuda_(Tensor& scaled_grad,
Tensor& found_inf,
const Tensor& inv_scale)
{
// The only way we reach this function is through _amp_foreach_non_finite_check_and_unscale_cuda_, so no input checks.
// It's not obvious gpu_kernel always guards onto its argument. Guarding here just in case.
const OptionalDeviceGuard device_guard(device_of(scaled_grad));
// Acts on scaled_grad in place.
auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
iter.dtype(),
"_amp_non_finite_check_and_unscale_cuda",
[&iter, &found_inf, &inv_scale] {
auto* found_inf_ptr = found_inf.data_ptr<float>();
auto* inv_scale_ptr = inv_scale.data_ptr<float>();
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
gpu_kernel(iter,
[found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (scalar_t val_in) -> scalar_t {
auto val = static_cast<opmath_t>(val_in);
if (!isfinite_ensure_cuda_math(val)) {
*found_inf_ptr = 1.f;
}
// Every thread accesses inv_scale, but it will hit in cache.
const auto inv_scale_val = *inv_scale_ptr;
return static_cast<scalar_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
});
});
}
} // anonymous namespace
// Multiplies each tensor in scaled_grads by inv_scale in-place.
// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf to 1.0.
// Uses multi tensor apply (MTA) to process all MTA-safe tensors.
//
// Args:
// scaled_grads: A TensorList of scaled gradient tensors. May contain infs or NaNs.
// found_inf: A single-element float tensor to which 1.0 will be written if any gradient contain infs/nans.
// Pre-zeroing found_inf, if appropriate, is the responsibility of the caller.
// inv_scale: The inverse of the scale factor by which scaled_grads are currently multiplied.
void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads,
Tensor& found_inf,
const Tensor& inv_scale)
{
if (scaled_grads.size() == 0) {
return;
}
TORCH_CHECK(inv_scale.is_cuda(), "inv_scale must be a CUDA tensor.");
TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor.");
TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor.");
TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor.");
// Ensures client code (GradScaler) filtered scaled_grads by dtype.
check_foreach_api_restrictions(scaled_grads);
std::vector<std::vector<at::Tensor>> tensor_lists;
// is_non_overlapping_and_dense() is not available in Python.
// GradScaler can't filter for it. We need to filter here.
if (can_use_fast_route(scaled_grads)) {
// Hopefully common case.
// can_use_fast_route is true, which confirms:
// - all scaled_grads are strided
// - all scaled_grads are non overlapping and dense
// - all scaled_grads are on the same device
// - all scaled_grads are of the same dtype
TORCH_CHECK(scaled_grads[0].is_cuda(), "scaled_grads must be CUDA tensors.");
// Sets up MTA launch to use scaled_grads as-is.
tensor_lists.emplace_back(scaled_grads.vec());
} else {
// Hopefully uncommon case.
// can_use_fast_route is an all-or-nothing check. In this path it was false,
// so any of the above confirmations could have gone wrong.
// We filter MTA-safe tensors into an MTA-able list.
// If a tensor is acceptable but not MTA-safe, we fall back to the TensorIterator kernel.
// If a tensor is unacceptable, we throw an error to blame GradScaler.
tensor_lists.resize(1);
tensor_lists[0].reserve(scaled_grads.size());
auto expected_device = scaled_grads[0].device();
const auto expected_dtype = scaled_grads[0].scalar_type();
for (const Tensor& t : scaled_grads) {
// Ensures GradScaler filtered scaled_grads by device.
TORCH_CHECK(t.is_cuda(), "one of scaled_grads was not a CUDA tensor.");
TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device.");
TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor.");
if (!t.is_non_overlapping_and_dense() || t.scalar_type() != expected_dtype) {
// t is acceptable but not MTA-safe. Falls back to single-tensor TensorIterator kernel.
_amp_non_finite_check_and_unscale_cuda_(const_cast<Tensor&>(t),
found_inf,
inv_scale);
} else {
tensor_lists[0].push_back(t);
}
}
if (tensor_lists[0].size() == 0) {
return;
}
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
tensor_lists[0][0].scalar_type(),
"_amp_foreach_non_finite_check_and_unscale_cuda",
[&tensor_lists, &found_inf, &inv_scale] {
auto* found_inf_ptr = found_inf.data_ptr<float>();
auto* inv_scale_ptr = inv_scale.data_ptr<float>();
using opmath_t = get_opmath_t<scalar_t>::opmath_t;
// multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly.
multi_tensor_apply<1>(tensor_lists,
UnaryOpFunctor<scalar_t,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0>(),
[found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
// There is a slight asymmetry here with the TensorIterator kernel above.
// MTA Functors ensure val comes in as opmath_t rather than scalar_t.
if (!isfinite_ensure_cuda_math(val)) {
*found_inf_ptr = 1.f;
}
// Every thread accesses inv_scale, but it will hit in cache.
const auto inv_scale_val = *inv_scale_ptr;
return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
});
});
}
// amp_update_scale_cuda_kernel is launched with a single thread to compute the new scale.
// The scale factor is maintained and updated on the GPU to avoid synchronization.
__global__ void amp_update_scale_cuda_kernel(float* current_scale,
int* growth_tracker,
float* found_inf,
double growth_factor,
double backoff_factor,
int growth_interval)
{
if (*found_inf) {
*current_scale = (*current_scale)*backoff_factor;
*growth_tracker = 0;
} else {
// Entering this branch means we just carried out a successful step,
// so growth_tracker is incremented before comparing to growth_interval.
auto successful = (*growth_tracker) + 1;
if (successful == growth_interval) {
*current_scale = (*current_scale)*growth_factor;
*growth_tracker = 0;
} else {
*growth_tracker = successful;
}
}
}
// _amp_update_scale_cuda asynchronously updates the scale tensor in place.
//
// Args:
// current_scale: A one-element cuda float tensor containing the scale value.
// growth_tracker: A one-element torch.cuda.IntTensor containing the number of recent consecutive unskipped steps.
// found_inf: A one-element cuda float tensor. If > 0, indicates that infs/nans were found by the relevant
// prior _amp_non_finite_check_and_unscale_cuda call, and 0 if no infs/nans were found.
// growth_factor: Multiplier if no infs/NaNs were found (typically slightly > 1).
// backoff_factor: Multiplier if infs/NaNs were found (typically 0.5).
// growth_interval: Number of consecutive unskipped steps that must occur for current_scale to be multiplied by
// growth_factor.
//
// Returns:
// current_scale
Tensor& _amp_update_scale_cuda_(Tensor& current_scale,
Tensor& growth_tracker,
const Tensor& found_inf,
double growth_factor,
double backoff_factor,
int64_t growth_interval)
{
TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
TORCH_CHECK(growth_tracker.numel() == 1, "growth_tracker must be a 1-element tensor.");
TORCH_CHECK(current_scale.numel() == 1, "current_scale must be a 1-element tensor.");
TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");
TORCH_CHECK(growth_tracker.scalar_type() == at::ScalarType::Int, "growth_tracker must be an int tensor.");
TORCH_CHECK(current_scale.scalar_type() == at::ScalarType::Float, "current_scale must be a float tensor.");
TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor.");
amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
current_scale.data_ptr<float>(),
growth_tracker.data_ptr<int>(),
found_inf.data_ptr<float>(),
growth_factor,
backoff_factor,
growth_interval);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return current_scale;
}
}} // namespace at::native