forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_dynamic_shapes.py
2213 lines (1996 loc) · 78.9 KB
/
test_dynamic_shapes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Owner(s): ["oncall: jit"]
import contextlib
import copy
import itertools
import inspect
import math
import operator
import re
import sympy
import torch
import torch.fx
import torch.nn.functional as F
from torch import sym_int, SymBool, SymFloat, SymInt
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental import sym_node
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.sym_node import to_node, SymNode, method_to_operator
from torch.fx.experimental.symbolic_shapes import (
DimConstraints,
DimDynamic,
expect_true,
guard_bool,
guard_float,
guard_int,
GuardOnDataDependentSymNode,
ShapeEnv,
is_symbolic,
StatelessSymbolicContext,
statically_known_true,
_constrain_range_for_size,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree
from torch.utils._sympy.functions import FloorDiv, Mod
# Shorthand for the ATen operator namespace used by the meta registrations in this file.
aten = torch.ops.aten
# Registry filled by ``register_meta``: maps an operator overload to the Python
# meta function that FakeSymbolicTensor.__torch_dispatch__ should call for it.
meta_funcs = {}
def register_meta(op):
    """Return a decorator that records the decorated function in ``meta_funcs``.

    ``op`` may be a single operator overload or a pytree (for example a list)
    of overloads; the decorated function is stored under every leaf and is
    returned unchanged.
    """
    def bind(meta_fn):
        def record(overload):
            meta_funcs[overload] = meta_fn

        pytree.tree_map_(record, op)
        return meta_fn

    return bind
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
    """Meta function for add/sub: allocate a result shaped like ``a``.

    ``b`` is accepted to match the operator signature but is not inspected.
    """
    result_shape = a.shape
    return a.new_empty(result_shape)
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    """Meta function for cat: compute the output shape and allocate it.

    Args:
        tensors: non-empty sequence of tensors to concatenate.
        dim: dimension to concatenate along; negative values count from the
            last dimension of the first tensor.

    Returns:
        ``tensors[0].new_empty(...)`` whose ``dim`` extent is the sum of the
        inputs' extents along ``dim`` and whose other extents are those of
        the first tensor.

    Raises:
        AssertionError: if any non-concatenated extent differs from the
            first tensor's.
    """
    shape = tensors[0].shape
    # Without this, a negative ``dim`` never equals ``idx`` below, so the
    # concatenated extent would stay 0 and the cat dimension would wrongly be
    # required to match across inputs.
    if dim < 0:
        dim += len(shape)

    concat_length = 0
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length

    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    """Meta function for narrow_copy: same shape as ``a`` except ``dim`` becomes ``length``.

    ``start`` and any extra keyword arguments are accepted but not used.
    """
    narrowed = tuple(
        length if axis == dim else extent
        for axis, extent in enumerate(a.shape)
    )
    return a.new_empty(narrowed)
@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    """Meta function for expand: allocate a tensor of ``size`` from ``a``.

    ``implicit`` is accepted to match the operator signature but is not used.
    """
    expanded = a.new_empty(size)
    return expanded
def create_contiguous(shape):
    """Return the row-major (contiguous) strides for ``shape`` as a list.

    The stride of dimension ``i`` is the product of the extents of all
    dimensions after ``i``; the last dimension has stride 1. An empty shape
    yields an empty list. Extents only need to support ``*``, so symbolic
    sizes work as well as plain ints.
    """
    if len(shape) == 0:
        return []
    # Build innermost-first. Each stride multiplies in the extent of the
    # dimension to its right, hence the walk over shape[1:] (not shape[:-1]).
    strides = [1]
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))
class FakeSymbolicTensor(torch.Tensor):
    """Wrapper-subclass tensor used by these tests to carry symbolic shapes.

    Operators reach ``__torch_dispatch__``, which forwards to the Python meta
    functions registered in ``meta_funcs``; anything unregistered (other than
    ``aten.new_empty``) raises ``RuntimeError``.
    """

    @staticmethod
    def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
        # TODO: this is wrong in general
        # NOTE: ``sym_strides`` is ignored; strides are always recomputed from
        # ``sym_shape`` via create_contiguous.
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls, sym_shape,
            sym_stride, storage_offset,
            dtype=dtype, layout=layout, requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        """Create a FakeSymbolicTensor of ``shape`` with this tensor's dtype/layout/device."""
        return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        """Route ``func_overload`` to its registered meta function.

        Raises:
            RuntimeError: if no meta function is registered for the overload
                and it is not ``aten.new_empty.default``.
        """
        # The signature defaults kwargs to None; ``**None`` would raise
        # TypeError, so normalise before unpacking.
        if kwargs is None:
            kwargs = {}

        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)

        raise RuntimeError(f"operator {func_overload} not supported")
def create_symbolic_tensor(name, arg, shape_env):
    """Wrap ``arg`` in a FakeSymbolicTensor whose sizes come from ``shape_env``.

    Every dimension is marked ``DimDynamic.DUCK`` with no constraint, and the
    symbols are sourced from ``ConstantSource(name)``. dtype, layout,
    requires_grad and device are copied from ``arg``.
    """
    from torch._dynamo.source import ConstantSource

    ndim = arg.dim()
    context = StatelessSymbolicContext(
        dynamic_sizes=[DimDynamic.DUCK] * ndim,
        constraint_sizes=[None] * ndim,
    )
    sizes, strides, offset = shape_env.create_symbolic_sizes_strides_storage_offset(
        arg,
        source=ConstantSource(name),
        symbolic_context=context,
    )
    return FakeSymbolicTensor(
        sizes, strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, offset
    )
def create_symtype(cls, pytype, shape_env, val, duck=True):
    """Create a fresh symbol in ``shape_env`` hinted with ``val`` and wrap it in ``cls``.

    Args:
        cls: wrapper type to construct from the SymNode (e.g. SymInt).
        pytype: Python type passed to the SymNode.
        shape_env: environment in which the symbol is allocated.
        val: hint value for the symbol.
        duck: use ``DimDynamic.DUCK`` when true, ``DimDynamic.DYNAMIC`` otherwise.
    """
    from torch._dynamo.source import ConstantSource

    # The source name is made unique by the number of symbols already present.
    source = ConstantSource(f"__testing_only{len(shape_env.var_to_val)}")
    dim_kind = DimDynamic.DUCK if duck else DimDynamic.DYNAMIC
    symbol = shape_env.create_symbol(
        val,
        source=source,
        dynamic_dim=dim_kind,
        constraint_dim=None,
    )
    node = SymNode(symbol, shape_env, pytype, hint=val)
    return cls(node)
# TODO: default duck to False
def create_symint(shape_env, i: int, duck=True):
    """Create a SymInt in ``shape_env`` hinted with ``i``."""
    return create_symtype(cls=SymInt, pytype=int, shape_env=shape_env, val=i, duck=duck)
def create_symbool(shape_env, b: bool):
    """Create a SymBool in ``shape_env`` hinted with ``b``."""
    return create_symtype(cls=SymBool, pytype=bool, shape_env=shape_env, val=b)
def create_symfloat(shape_env, f: float):
    """Create a SymFloat in ``shape_env`` hinted with ``f``."""
    return create_symtype(cls=SymFloat, pytype=float, shape_env=shape_env, val=f)
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase):
    # Tests for symbolic ints/bools/floats created from a ShapeEnv: arithmetic,
    # use as tensor sizes through FakeSymbolicTensor, the guards recorded in
    # ``shape_env.guards``, and expect_true / statically_known_true behaviour.

    def test_arith_ops(self):
        # Binary arithmetic on SymInts must agree with the same op on their hints.
        shape_env = ShapeEnv()
        symints = []
        for i in range(2, 5):
            symints.append((i, create_symint(shape_env, i)))

        ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]

        for op in ops:
            for (lhs_val, lhs_sym), (rhs_val, rhs_sym) in itertools.permutations(symints, 2):
                if isinstance(lhs_sym, int):
                    continue
                # Only the dividing ops need a non-zero right operand.
                if op in (operator.mod, operator.floordiv) and rhs_val == 0:
                    continue
                self.assertTrue(op(lhs_sym, rhs_sym) == op(lhs_val, rhs_val))

    def test_reverse_arith_ops(self):
        # Reflected operators: plain int on the left, SymInt on the right.
        shape_env = ShapeEnv()

        a = create_symint(shape_env, 2)
        self.assertTrue(5 // a == 5 // 2)

        a = create_symint(shape_env, 2)
        self.assertTrue(5 * a == 5 * 2)

    def test_roundtrip(self):
        # Sizes and storage offset read back from a symbolic tensor are SymInts
        # that compare equal to the concrete values.
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

        self.assertTrue(not isinstance(x.shape[0], SymNode))
        self.assertTrue(isinstance(x.shape[0], SymInt))

        self.assertTrue(x.shape[0] == 5)
        self.assertTrue(x.shape[1] == 4)
        # assertTrue(value, 3) would treat 3 as the failure message; compare explicitly.
        self.assertTrue(x.shape[2] == 3)

        self.assertTrue(x.size()[0] == 5)
        self.assertTrue(x.size()[1] == 4)
        # Should be simplifiable to an integer.
        # Ref: https://github.com/pytorch/pytorch/pull/107492
        self.assertTrue(isinstance(x.size()[1], SymInt))
        self.assertTrue(isinstance(x.size()[1].node.maybe_as_int(), int))  # due to guard above
        self.assertTrue(x.size()[2] == 3)

        self.assertTrue(x.size(0) == 5)
        self.assertTrue(x.size(1) == 4)
        self.assertTrue(x.size(2) == 3)
        self.assertTrue(isinstance(x.size(2), SymInt))
        self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))

        y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
        self.assertTrue(isinstance(y.storage_offset(), SymInt))
        self.assertTrue(y.storage_offset() == 12)

    def test_binary(self):
        # Addition of symbolic tensors keeps the left operand's shape.
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # broadcasting
        y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

    def test_symint_args(self):
        # SymInts (and expressions over them) are accepted as operator arguments.
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
        LAST_DIM = 2
        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == y.shape[2])

        # arithmetic expr with two symints
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == 2)

        # arithmetic expr with a symint and python int
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
        self.assertTrue(z.shape[2] == 2)

    def test_symint_vargs(self):
        # expand() accepts SymInts as varargs and as a list, mixed with ints.
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

        # varargs
        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # shape list
        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(x.shape[0], y.shape[1], 3)
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints in a list
        z = y.expand((x.shape[0], y.shape[1], 3))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(5, y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python ints and symints in a list
        z = y.expand((5, y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        z = y.expand((y.shape[1],))
        z = y.expand(y.shape[1])

    def test_stride(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
        self.assertIsInstance(x.stride()[0], SymInt)

    def test_size_expressions(self):
        # Branching on a symbolic size records a guard on that size.
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        expand_x = x.expand(x.shape[0], x.shape[0])
        # Both branches do the same work; the comparison itself is what
        # installs the guard inspected below.
        if expand_x.shape[0] > 3:
            result = expand_x + expand_x
        else:
            result = expand_x + expand_x

        gt_op, _bt = shape_env.guards[-1]
        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
        # assertTrue(a, b) only checks truthiness of ``a``; compare the strings.
        self.assertEqual(str(x.shape[0]), str(gt_op.args[0]))
        self.assertEqual(str(expand_x.shape[1]), str(x.shape[0]))
        self.assertEqual(str(expand_x.shape[1]), str(result.shape[0]))

    def test_numel(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        self.assertIsInstance(x.numel(), torch.SymInt)
        self.assertIsInstance(torch.numel(x), torch.SymInt)

        x = torch.rand(3, 3)
        self.assertIsInstance(x.numel(), int)
        self.assertIsInstance(torch.numel(x), int)

    def test_int_to_float(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        r = torch.sym_float(x.shape[0])
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))

    def test_aten_ops(self):
        # Calling the aten overloads directly with SymInt arguments must not raise.
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])

        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
        torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])

    def test_fx_trace_intlist(self):
        class CustomModule(torch.nn.Module):
            def forward(self, x):
                bs, c, h, w = x.shape
                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))

        m = CustomModule()
        # should not TypeError: pad(): argument 'pad' (position 2) must be
        # tuple of ints, not tuple
        torch.fx.symbolic_trace(m)

    def test_meta_symint(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        r = torch.empty(a0, device='meta')
        self.assertIsInstance(r.shape[0], SymInt)

    def test_guard_int(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        self.assertEqual(guard_int(a0), 2)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_sym_int(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = sym_int(a0)
        self.assertEqual(r, 5)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")

        a1 = create_symint(shape_env, 7)
        r = sym_int(a1 / 2)
        self.assertEqual(guard_int(r), 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(floor(s1/2), 3)""")

        a3 = create_symint(shape_env, 3)
        r = sym_int(2.0 * torch.sym_float(a3))
        self.assertEqual(guard_int(r), 6)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(2*s2, 6)""")

    def test_sym_sqrt(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 4)
        r = torch._sym_sqrt(a0)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s0), 2)""")

    def test_sym_floor(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.floor(a0 / 2)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""")
        r = math.floor(3.0 * a0)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")

    def test_sym_ceil(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.ceil(a0 / 2)
        self.assertEqual(r, 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""")
        # NOTE(review): this half calls math.floor, exactly like test_sym_floor;
        # presumably math.ceil was intended. Confirm the expected guard string
        # before changing it.
        r = math.floor(3.0 * a0)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")

    def test_sym_ite(self):
        # sym_ite with a plain bool returns the chosen operand itself; with a
        # SymBool it builds a symbolic result and guards only when it is read.
        shape_env = ShapeEnv()
        t = create_symint(shape_env, 5)
        f = create_symint(shape_env, 4)
        b1 = True
        r1 = torch.sym_ite(b1, t, f)
        self.assertTrue(r1 is t)
        b2 = False
        r2 = torch.sym_ite(b2, t, f)
        self.assertTrue(r2 is f)
        b3 = t == 5
        r3 = torch.sym_ite(b3, t, f)
        self.assertEqual(len(shape_env.guards), 0)
        self.assertEqual(r3, 5)
        self.assertEqual(type(t), type(r3))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""")
        b4 = f == 5
        r4 = torch.sym_ite(b4, t, f)
        self.assertEqual(len(shape_env.guards), 1)
        self.assertEqual(r4, 4)
        self.assertEqual(type(f), type(r4))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""")

    def test_tracing_sym_ite(self):
        def f(x):
            b = x.shape[0] == 5
            ret = torch.sym_ite(b, x.shape[0], x.shape[1])
            return ret

        gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
        self.assertEqual(len(gm.shape_env.guards), 0)
        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    eq = sym_size_int == 5
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1);  x_1 = None
    sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1);  eq = sym_size_int = sym_size_int_1 = None
    return sym_ite""")
        r1 = gm(torch.ones(4, 5))
        self.assertIsInstance(r1, int)
        self.assertEqual(r1, 5)
        r2 = gm(torch.ones(5, 4))
        self.assertIsInstance(r2, int)
        self.assertEqual(r2, 5)

    def test_int_conversion(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        int(a0)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_data_dependent_guard(self):
        shape_env = ShapeEnv()
        s0 = shape_env.create_unbacked_symint()
        self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))

    def test_expect_true_basic(self):
        shape_env = ShapeEnv()
        i0 = shape_env.create_unbacked_symint()
        i0_sym = i0.node.expr
        # This doesn't error
        self.assertTrue(expect_true(i0 == 0))
        # This generates a deferred runtime assert via replacement
        self.assertEqual(shape_env.replacements[i0_sym], 0)
        # After expecting true, guards now resolve given the runtime assert
        bool(i0 == 0)

    def test_expect_true_with_s0(self):
        shape_env = ShapeEnv()
        s0 = create_symint(shape_env, 5)
        i0 = shape_env.create_unbacked_symint()
        self.assertTrue(expect_true(i0 <= s0))
        self.assertExpectedInline(
            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
            """[-s0 + u0 <= 0]"""
        )
        self.assertTrue(i0 <= s0)
        self.assertFalse(i0 > s0)

    def test_expect_true_prefer_later(self):
        shape_env = ShapeEnv()
        i0 = shape_env.create_unbacked_symint()
        i1 = shape_env.create_unbacked_symint()
        i1_sym = i1.node.expr
        self.assertTrue(expect_true(i0 + i1 == 10))
        # Importantly, this is put in i1, not i0!
        self.assertExpectedInline(
            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]),
            """[Eq(u0 + u1, 10)]"""
        )
        self.assertTrue(i0 + i1 == 10)
        # NB: We currently don't support deriving that we can substitute
        # i0 + i1 with 10; maybe we should, but this means our rewriting
        # system is no longer confluent (it's probably OK though, because
        # you're unlikely to get other equalities like this on the
        # unbacked SymInts.)

    def test_unbacked_substitution(self):
        shape_env = ShapeEnv()
        i0 = shape_env.create_unbacked_symint()
        i1 = shape_env.create_unbacked_symint()
        _constrain_range_for_size(i0)
        _constrain_range_for_size(i1)
        self.assertTrue(expect_true(i0 == i1 * 4))
        self.assertExpectedInline(str(i0), """4*u1""")

        i2 = shape_env.create_unbacked_symint()
        i3 = shape_env.create_unbacked_symint()
        _constrain_range_for_size(i2)
        _constrain_range_for_size(i3)
        self.assertTrue(expect_true(i2 * 4 == i3))
        self.assertExpectedInline(str(i3), """4*u2""")

    def test_avoid_unbacked_substitution(self):
        shape_env = ShapeEnv()
        i0 = shape_env.create_unbacked_symint()
        _constrain_range_for_size(i0)
        i1 = shape_env.create_unbacked_symint()
        _constrain_range_for_size(i1)
        self.assertTrue(expect_true(i0 == 10 - i1))
        self.assertExpectedInline(str(i0), """u0""")

    def test_expect_true_double_digits(self):
        shape_env = ShapeEnv()
        ia = [shape_env.create_unbacked_symint() for _ in range(11)]  # allocates 11 symbols: u0 .. u10
        self.assertEqual(str(ia[-1]), "u10")
        self.assertTrue(expect_true(sum(ia) == 20))
        self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1)

    def test_expect_true_refine_range(self):
        # expect_true on an inequality makes later range queries statically known.
        shape_env = ShapeEnv()
        for i, rel in enumerate([lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]):
            with self.subTest(f"i = {i}"):
                i0 = shape_env.create_unbacked_symint()
                self.assertTrue(expect_true(rel(i0)))
                self.assertTrue(statically_known_true(i0 != 3))
                self.assertTrue(statically_known_true(i0 != 4))
                self.assertFalse(statically_known_true(i0 != 5))
                self.assertFalse(statically_known_true(i0 != 6))
                self.assertTrue(statically_known_true(i0 > 4))
                self.assertTrue(statically_known_true(i0 >= 5))

        for i, rel in enumerate([lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]):
            with self.subTest(f"i = {i}"):
                i0 = shape_env.create_unbacked_symint()
                self.assertTrue(expect_true(rel(i0)))
                self.assertFalse(statically_known_true(i0 != 2))
                self.assertFalse(statically_known_true(i0 != 3))
                self.assertTrue(statically_known_true(i0 != 4))
                self.assertTrue(statically_known_true(i0 != 5))
                self.assertTrue(statically_known_true(i0 < 4))
                self.assertTrue(statically_known_true(i0 <= 5))

    def test_guard_refine_range(self):
        # Guarding on an inequality (whichever way it evaluates) refines what
        # is statically known about the backed symbol afterwards.
        shape_env = ShapeEnv()
        for i, rel in enumerate([lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]):
            with self.subTest(f"i = {i}"):
                i0 = create_symint(shape_env, 10, duck=False)
                self.assertTrue(bool(rel(i0)))
                self.assertTrue(statically_known_true(i0 != 3))
                self.assertTrue(statically_known_true(i0 != 4))
                self.assertFalse(statically_known_true(i0 != 5))
                self.assertFalse(statically_known_true(i0 != 6))
                self.assertTrue(statically_known_true(i0 > 4))
                self.assertTrue(statically_known_true(i0 >= 5))

        for i, rel in enumerate([lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]):
            with self.subTest(f"i = {i}"):
                i0 = create_symint(shape_env, 2, duck=False)
                self.assertFalse(bool(rel(i0)))
                self.assertFalse(statically_known_true(i0 != 3))
                self.assertFalse(statically_known_true(i0 != 4))
                self.assertTrue(statically_known_true(i0 != 5))
                self.assertTrue(statically_known_true(i0 != 6))
                self.assertTrue(statically_known_true(i0 <= 4))
                self.assertTrue(statically_known_true(i0 < 5))

        for i, rel in enumerate([lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]):
            with self.subTest(f"i = {i}"):
                i0 = create_symint(shape_env, 2, duck=False)
                self.assertTrue(bool(rel(i0)))
                self.assertFalse(statically_known_true(i0 != 2))
                self.assertFalse(statically_known_true(i0 != 3))
                self.assertTrue(statically_known_true(i0 != 4))
                self.assertTrue(statically_known_true(i0 != 5))
                self.assertTrue(statically_known_true(i0 < 4))
                self.assertTrue(statically_known_true(i0 <= 3))

        for i, rel in enumerate([lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]):
            with self.subTest(f"i = {i}"):
                i0 = create_symint(shape_env, 10, duck=False)
                self.assertFalse(bool(rel(i0)))
                self.assertTrue(statically_known_true(i0 != 2))
                self.assertTrue(statically_known_true(i0 != 3))
                self.assertFalse(statically_known_true(i0 != 4))
                self.assertFalse(statically_known_true(i0 != 5))
                self.assertTrue(statically_known_true(i0 >= 4))
                self.assertTrue(statically_known_true(i0 > 3))

    def test_non_overlapping_and_dense(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = torch.empty_strided((a0, 7), (1, a0), device='meta')
        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))

    def test_specialize_zero_one(self):
        # With specialize_zero_one=True, ``a0 != 1`` needs no guard; with it
        # off, the comparison records one.
        shape_env = ShapeEnv(specialize_zero_one=True)
        a0 = create_symint(shape_env, 5)
        assert a0 != 1
        self.assertEqual(len(shape_env.guards), 0)

        shape_env = ShapeEnv(specialize_zero_one=False)
        a0 = create_symint(shape_env, 5)
        assert a0 != 1
        self.assertEqual(len(shape_env.guards), 1)

    def test_duck_shape(self):
        # With duck_shape=True two symbols with the same hint compare equal
        # without a guard; with it off, the comparison records one.
        shape_env = ShapeEnv(duck_shape=True)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        assert a0 == a1
        self.assertEqual(len(shape_env.guards), 0)

        shape_env = ShapeEnv(duck_shape=False)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        assert a0 == a1
        self.assertEqual(len(shape_env.guards), 1)

    def test_int_bool(self):
        # See https://github.com/pytorch/pytorch/issues/95981
        shape_env = ShapeEnv(duck_shape=True)
        a0 = create_symint(shape_env, 5)
        assert a0
        self.assertEqual(len(shape_env.guards), 0)

    def test_symint_as_scalar(self):
        # A SymInt passed as the ``alpha`` scalar must reach __torch_dispatch__
        # still backed by the same SymNode.
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)

        sym_int_encountered = False

        class TestSymInt(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                assert func == torch.ops.aten.add.Tensor

                nonlocal sym_int_encountered
                # WARNING: do not do identity tests on the outer
                # SymInt/SymFloat, they are NOT STABLE
                sym_int_encountered = kwargs["alpha"].node is a0.node
                kwargs["alpha"] = 0
                return func(*args)

        x = torch.rand([4, 4])
        with TestSymInt():
            torch.add(x, x, alpha=a0)

        self.assertTrue(sym_int_encountered)

    def test_deepcopy(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        assert a0 < 4
        new_shape_env = copy.deepcopy(shape_env)
        self.assertEqual(len(new_shape_env.guards), 1)

    def test_print_readable_with_symints(self):
        def f(a, b):
            dim0 = a.shape[0] + b.shape[0]
            dim1 = a.shape[1] + b.shape[1]
            d = a.new_empty(dim0, dim1)
            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
        out = fx_g.print_readable(print_output=False)

        self.assertExpectedInline(out.strip(), """\
class f(torch.nn.Module):
    def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"):
        # No stacktrace found for following nodes
        sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0)
        sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0)
        add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1;  sym_size_int = sym_size_int_1 = None
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1)
        sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1);  b_1 = None
        add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3;  sym_size_int_2 = sym_size_int_3 = None
        new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False);  a_1 = add = add_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True);  new_empty = None
        getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0]
        getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1];  native_dropout = None
        return (getitem, getitem_1)""")  # noqa: B950

    def test_statically_known_true(self):
        shape_env = ShapeEnv()
        s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))

        # Statically known true
        self.assertTrue(statically_known_true(True))
        self.assertTrue(statically_known_true(s2 == s2))
        self.assertTrue(statically_known_true(s2 * s3 > s3))
        self.assertTrue(statically_known_true(s3 * s4 > s4))
        self.assertTrue(statically_known_true((s3 + s3) % 2 == 0))

        # Statically known false
        self.assertFalse(statically_known_true(False))
        self.assertFalse(statically_known_true(s3 * s4 <= s4))
        self.assertFalse(statically_known_true((s3 + s3) % 2 == 1))

        # True for hints, but not known statically
        self.assertFalse(statically_known_true(s2 + s2 == s4))
        self.assertFalse(statically_known_true(s4 % s2 == 0))
        self.assertFalse(statically_known_true(s2 != s3))
        self.assertFalse(statically_known_true(s3 * s4 > s2))

        # False for hints, but not known statically
        self.assertFalse(statically_known_true(s2 == s3))
        self.assertFalse(statically_known_true(s2 > s3))
        self.assertFalse(statically_known_true(s3 + s3 == s4))

        # No guards should be generated
        self.assertEqual(len(shape_env.guards), 0)
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestSymNumberMagicMethods(TestCase):
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
    # Helper function: apply magic method ``fn`` to plain Python inputs to get
    # a reference result, then repeat with the first, the second, and both
    # arguments wrapped as symbolic values, checking each against the
    # reference. Unary functions only exercise the first-argument case.
    # NB: don't use one as that will get specialized
    seed_node = (create_symint(shape_env, 2) / 2.).node
    bool_seed_node = (create_symint(shape_env, 2) == 2).node

    def get_sym_inp(inp):
        # NB: this must come before int
        if isinstance(inp, bool):
            return torch.SymBool(to_node(bool_seed_node, inp))
        elif isinstance(inp, int):
            return torch.SymInt(to_node(seed_node, inp))
        else:
            return torch.SymFloat(to_node(seed_node, inp))

    def maybe_xfail(inp1, inp2):
        # Returns an assertRaises context for input combinations that are
        # expected to raise, and a no-op context otherwise.
        if fn == "sym_sqrt" and inp1 < 0:
            # ValueError: math domain error
            return self.assertRaises((ValueError,))
        elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
            # ZeroDivisionError: division by zero
            return self.assertRaises((ZeroDivisionError,))
        elif fn == "pow" and inp1 == 0 and inp2 < 0:
            # ZeroDivisionError: 0.0 cannot be raised to a negative power
            return self.assertRaises((ZeroDivisionError,))
        elif fn == "pow" and inp1 < 0 and inp2 in (2.5, -2.5) and (
            type(inp1) in (SymFloat, SymInt) or
            type(inp2) in (SymFloat, SymInt)
        ):
            # Complex result, which we do not support:
            # TypeError: Cannot convert complex to float
            return self.assertRaises((TypeError,))
        elif fn in ("lshift", "rshift") and not (
            isinstance(inp1, (SymInt, int)) and
            isinstance(inp2, (SymInt, int))
        ):
            # TypeError: unsupported operand type(s)
            return self.assertRaises((TypeError,))
        elif fn in ("lshift", "rshift") and inp2 < 0:
            # ValueError: math domain error
            return self.assertRaises((ValueError,))
        else:
            return contextlib.nullcontext()

    lambda_apply = method_to_operator(fn)

    def guard_fn(v):
        if type(v) in (SymBool, bool):
            return guard_bool(v)
        elif type(v) in (SymFloat, float):
            return guard_float(v)
        else:  # SymInt, int
            return guard_int(v)

    def apply(lhs, rhs):
        # Unary functions ignore the second operand.
        if is_unary_fn:
            return lambda_apply(lhs)
        return lambda_apply(lhs, rhs)

    # Get reference result
    with maybe_xfail(inp1, inp2):
        ref_out = apply(inp1, inp2)

    def check(lhs, rhs):
        # Run the op on (possibly symbolic) operands and compare the guarded
        # result with the reference. If the op is expected to raise, the
        # assertRaises context swallows it and the comparison is skipped.
        with maybe_xfail(lhs, rhs):
            out = apply(lhs, rhs)
            if fn not in sym_node.alternate_impl_if_hinted_methods:
                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

    # Symified first arg
    sym_inp1 = get_sym_inp(inp1)
    check(sym_inp1, inp2)

    if is_unary_fn:
        return

    # Symified second arg
    sym_inp2 = get_sym_inp(inp2)
    check(inp1, sym_inp2)

    # Symified both args
    check(sym_inp1, sym_inp2)
@parametrize("fn", list(sym_node.magic_methods.keys()))
def test_bool_method(self, fn):
    # Runs _do_test with (True, False) for every bool magic method.
    # sym_ite has its own tests
    is_bool_method = fn in sym_node.bool_magic_methods
    if fn == "sym_ite" or not is_bool_method:
        self.skipTest(f"{fn} is non-bool")

    unary = fn in sym_node.unary_methods
    env = ShapeEnv()
    self._do_test(fn, True, False, env, unary)
@parametrize("fn", list(sym_node.magic_methods.keys()))
@parametrize("first_type", ["int", "float"])
@parametrize("second_type", ["int", "float"])
def test_method(self, fn, first_type, second_type):
    # Runs _do_test for every non-bool magic method over a grid of positive
    # and negative operand values, each in a fresh ShapeEnv.
    if first_type == "float":
        # TODO: Hmm, this looks like we skip all floats
        self.skipTest(f"{fn} is not a float magic method")

    any_int = "int" in (first_type, second_type)
    if any_int and fn in sym_node.only_float_magic_methods:
        self.skipTest(f"{fn} is not an int method")

    is_unary_fn = fn == "round" or fn in sym_node.unary_methods
    # Second argument is ignored for unary function. So only run for one type
    if is_unary_fn and second_type == "float":
        self.skipTest(f"{fn} is unary and already tested")

    if fn in sym_node.bool_magic_methods:
        self.skipTest(f"{fn} is bool")

    # Only floats here since these will be converted to int if necessary.
    # We also ignore complex and bool.
    if fn in ("sym_acos", "sym_asin"):
        third = 0.5  # avoid math domain error
    else:
        third = 2.5
    values = (0.0, 1.0, third)
    neg_values = tuple(-x for x in values)

    # Pool order: (+, +), (+, -), (-, +), (-, -).
    for lhs_pool, rhs_pool in itertools.product((values, neg_values), repeat=2):
        for inp1, inp2 in itertools.product(lhs_pool, rhs_pool):
            if first_type == "int":
                inp1 = int(inp1)
            if second_type == "int":
                inp2 = int(inp2)

            shape_env = ShapeEnv()

            self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
def get_constant_bool(self, val):
    # Wrap the constant-bool SymNode for ``val`` in a SymBool.
    node = torch._C._get_constant_bool_symnode(val)
    return SymBool(node)
def test_symnode_hashing(self):
    shape_env = ShapeEnv()

    # SymInt, SymBool, SymFloat are unhashable
    unhashable_values = [
        create_symint(shape_env, 3),
        create_symbool(shape_env, True),
        # We should be passing in float here, but create_symbol currently
        # only supports int
        create_symfloat(shape_env, 3.0),
    ]
    for value in unhashable_values:
        self.assertRaisesRegex(TypeError, "unhashable", hash, value)

    # NestedInt (SymInt), constant SymBool, SymNode are hashable
    nested_int = torch._C._get_nested_int
    j1 = nested_int(1, 1)
    j1_copy = nested_int(1, 1)
    j2 = nested_int(2, 1)
    t = self.get_constant_bool(True)
    t_copy = self.get_constant_bool(True)
    f = self.get_constant_bool(False)
    symint_node = create_symint(shape_env, 3).node
    const_bool_node = self.get_constant_bool(True).node

    # Equal values hash equal; unequal values hash differently.
    for same, same_copy, different in ((j1, j1_copy, j2), (t, t_copy, f)):
        self.assertIs(same == same_copy, True)
        self.assertEqual(hash(same), hash(same_copy))
        self.assertIs(same == different, False)
        self.assertNotEqual(hash(same), hash(different))

    # SymNodes only need to be hashable without raising.
    hash(symint_node)
    hash(const_bool_node)
def test_non_symbolic_symnode(self):
j1 = torch._C._get_nested_int(1, 1)
j2 = torch._C._get_nested_int(1, 1)
j3 = torch._C._get_nested_int(3, 1)
self.assertIsInstance(j1, torch.SymInt)
self.assertNotIsInstance(j1, int)
with self.assertRaisesRegex(RuntimeError, "add not supported by NestedIntSymNode"):
j1 + 3
self.assertFalse(j1 == 3)
with self.assertRaisesRegex(RuntimeError, "indeterminate"):
self.assertFalse(3 >= j2)
self.assertIs(j1 == j1, True)
self.assertIs(j1 == j2, True)
self.assertIs(j1 == j3, False)
self.assertIs(j1 != j3, True)
self.assertIs(j1 != j2, False)
x = self.get_constant_bool(True)
#
# Unary
#
# op(constant SymBool)
self.assertIs(x.__sym_not__(), False)
#
# Binary
#
# op(constant SymBool, bool)
# op(constant SymBool, constant SymBool)
# op(bool, constant SymBool)
self.assertIs(operator.and_(x, True), True)
self.assertIs(operator.and_(x, x), True)
self.assertIs(operator.and_(True, x), True)
# op(symbolic SymBool, constant Symbool)
# op(constant SymBool, symbolic Symbool)
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
b = create_symint(shape_env, 2)
c = a == b # symbolic SymBool
d = self.get_constant_bool(True)
e = operator.and_(c, d)
f = operator.and_(d, c)
self.assertTrue(is_symbolic(e))
self.assertTrue(is_symbolic(f))
self.assertIs(e.node.guard_bool("", 0), True)
self.assertIs(f.node.guard_bool("", 0), True)
# Comparing sizes
sz1 = torch.Size([j1, j1, j1])
sz2 = torch.Size([j1, j1, j1])
self.assertIs(sz1 == sz2, True)