forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_functionalization_of_rng_ops.py
350 lines (270 loc) · 11.3 KB
/
test_functionalization_of_rng_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
# Owner(s): ["oncall: pt2"]
import sys
import unittest
import torch
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from functorch.compile import aot_function, nop, min_cut_rematerialization_partition
from unittest.mock import patch
import functools
import torch.utils.checkpoint
from torch.testing._internal.common_utils import (
IS_CI,
IS_WINDOWS,
)
# torch.compile is not supported on Windows, so on Windows CI this module is
# skipped: exit cleanly when run as a script, otherwise raise SkipTest at
# import time.
if IS_CI and IS_WINDOWS:
    sys.stderr.write("torch.compile not supported on windows")
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("torch.compile not supported on windows")
def count_philox_rand(gm, args, freq):
assert [node.target for node in gm.graph.nodes].count(torch.ops.rngprims.philox_rand.default) == freq
return gm
class TestFunctionalizationRngOps(TestCase):
    """Tests for AOT tracing with ``functionalize_rng_ops`` enabled.

    Every test patches ``torch._functorch.config.functionalize_rng_ops`` to
    ``True``. Most tests seed the CUDA generator identically before an eager
    run and a traced run and compare the results; compiler callbacks built
    from ``count_philox_rand`` check how many ``philox_rand`` nodes each
    traced graph contains.
    """

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like(self, dtype, device):
        """Two ``rand_like`` calls trace to two philox nodes and match eager."""

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic(self, dtype, device):
        """``torch.compile(dynamic=True)`` matches eager across input shapes."""

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            # The seed doubles as the side length so every iteration uses a
            # different input shape.
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype)

            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic_bwd(self, dtype, device):
        """Forward outputs and input gradients match eager with dynamic shapes."""

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
            # Separate leaf for the compiled run so its gradient does not
            # accumulate into ``x.grad`` and the two can be compared.
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand(self, dtype, device):
        """Two ``torch.rand`` calls trace to two philox nodes and match eager."""
        shape = (10,)

        def fn(x):
            a = torch.rand(*shape, device=device, dtype=dtype) * x
            a = torch.rand(*shape, device=device, dtype=dtype) * a
            return a

        x = torch.rand(*shape, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_autograd_function(self, dtype, device):
        """A custom autograd Function with RNG in forward and backward matches eager."""
        shape = (16, 16)

        class Custom(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        custom = Custom.apply

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = custom(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        # Two rand_like calls in forward, one in backward.
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom = aot_function(custom, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_multiple_subgraphs(self, dtype, device):
        """Two separately traced custom ops chained together match eager."""
        # Checks that rng state is maintained when there are multiple aot traced
        # graphs.
        shape = (16, 16)

        class CustomOp1(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        class CustomOp2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                return a

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.rand_like(x)

        custom_op1 = CustomOp1.apply
        custom_op2 = CustomOp2.apply

        def fn(x):
            a = custom_op1(x)
            b = a.sin()
            return custom_op2(b)

        # CustomOp1: two random draws in forward, one in backward.
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)

        # CustomOp2: one random draw in forward, two in backward.
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=2)
        aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)

        def aot_fn(x):
            a = aot_custom_op1(x)
            b = a.sin()
            return aot_custom_op2(b)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            res = aot_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_set_get_rng_state(self, dtype, device):
        """Saving and restoring the CUDA RNG state inside ``fn`` matches eager."""

        def fn(x):
            a = torch.rand_like(x) * x
            state = torch.cuda.get_rng_state()
            a = torch.rand_like(x) * a
            torch.cuda.set_rng_state(state)
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            # Three rand_like calls in fn.
            fwd_compiler = functools.partial(count_philox_rand, freq=3)
            aot_fn = aot_function(fn, fwd_compiler)
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_min_cut_partitioner(self, dtype, device):
        """Outputs and gradients match eager with the min-cut partitioner."""
        # Checks that the calling convention is maintained
        shape = (16, 16)

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            a = torch.sin(a)
            a = torch.sin(a)
            a = torch.sin(a)
            return a

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = fn(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        # The forward graph must hold both random draws; the backward graph is
        # expected to contain none.
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_custom = aot_function(
            fn,
            fwd_compiler,
            bwd_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    # TODO - Dropout needs more work because of offset calculation
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_checkpoint(self, dtype, device):
        """Checkpointed dropout yields one philox node in forward and one in backward."""

        def g(x, y):
            return torch.nn.functional.dropout(x, 0.6)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False)

        x = torch.ones(2, 2, device=device, dtype=dtype, requires_grad=True)
        y = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)

        torch.cuda.manual_seed(123)
        # Eager run; its output is not compared against the traced result.
        fn(x, y)

        # With checkpointing we should recompute dropout in bwd, and should see philox_rand
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)

        # We can't check accuracy here because rand_like generates different
        # random numbers than dropout.
        res = aot_fn(x, y)
        res.sum().backward()

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_dropout_decomp(self, dtype, device):
        """Dropout traces to a graph containing exactly one philox node."""

        def fn(x):
            return torch.nn.functional.dropout(x, 0.6) * x

        x = torch.rand(10, device=device, dtype=dtype)

        # Ensure the decomp is happening
        aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1))

        # We can't check accuracy here because rand_like generates different
        # random numbers than dropout.
        aot_fn(x)
# Instantiate the device-type variants of TestFunctionalizationRngOps,
# restricted to CUDA (the tests seed and query the CUDA RNG directly).
only_for = ("cuda",)
instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for)
class NegativeTest(TestCase):
    """Failure-path checks for RNG functionalization."""

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_on_cpu(self, dtype, device):
        """Calling the AOT-traced function must raise ``RuntimeError``."""

        def fn(x):
            scaled = torch.rand_like(x) * x
            return torch.rand_like(x) * scaled

        inp = torch.rand(10, device=device, dtype=dtype)
        traced = aot_function(fn, nop)

        # The error is expected from the call itself, not from building the
        # wrapper above.
        self.assertRaises(RuntimeError, traced, inp)
# Instantiate the device-type variants of NegativeTest, restricted to CPU.
only_for = ("cpu",)
instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()