forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
NaiveConvolutionTranspose2d.cu
976 lines (859 loc) · 28.4 KB
/
NaiveConvolutionTranspose2d.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/im2col.cuh>
namespace at {
namespace native {
namespace {
static inline void slow_conv_transpose2d_shape_check(
const Tensor& input,
const Tensor& grad_output,
const Tensor& weight,
const Tensor& bias,
int kernel_height,
int kernel_width,
int stride_height,
int stride_width,
int pad_height,
int pad_width,
int output_padding_height,
int output_padding_width,
int dilation_height,
int dilation_width,
bool weight_nullable) {
TORCH_CHECK(
kernel_width > 0 && kernel_height > 0,
"kernel size should be greater than zero, but got kernel_height: ",
kernel_height,
" kernel_width: ",
kernel_width);
TORCH_CHECK(
stride_width > 0 && stride_height > 0,
"stride should be greater than zero, but got stride_height: ",
stride_height,
" stride_width: ",
stride_width);
TORCH_CHECK(
dilation_width > 0 && dilation_height > 0,
"dilation should be greater than zero, but got dilation_height: ",
dilation_height,
", dilation_width: ",
dilation_width);
TORCH_CHECK(
(output_padding_width < stride_width ||
output_padding_width < dilation_width) &&
(output_padding_height < stride_height ||
output_padding_height < dilation_height),
"output padding must be smaller than either stride or dilation, ",
"but got output_padding_height: ",
output_padding_height,
" output_padding_width: ",
output_padding_width,
" stride_height: ",
stride_height,
" stride_width: ",
stride_width,
" dilation_height: ",
dilation_height,
" dilation_width: ",
dilation_width);
if (weight.defined()) {
TORCH_CHECK(
weight.numel() != 0 && (weight.dim() == 2 || weight.dim() == 4),
"non-empty 2D or 4D weight tensor expected, but got: ",
weight.sizes());
if (bias.defined()) {
check_dim_size(bias, 1, 0, weight.size(1));
}
} else if (!weight_nullable) {
AT_ERROR("weight tensor is expected to be non-nullable");
}
int ndim = input.dim();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4) {
dimf++;
dimh++;
dimw++;
}
TORCH_CHECK(
input.numel() != 0 && (ndim == 3 || ndim == 4),
"non-empty 3D or 4D input tensor expected but got a tensor with size ",
input.sizes());
int64_t input_height = input.size(dimh);
int64_t input_width = input.size(dimw);
int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
(dilation_height * (kernel_height - 1) + 1) + output_padding_height;
int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
(dilation_width * (kernel_width - 1) + 1) + output_padding_width;
if (output_width < 1 || output_height < 1) {
AT_ERROR(
"Given input size per channel: (",
input_height,
" x ",
input_width,
"). Calculated output spatial size per channel: (",
output_height,
" x ",
output_width,
"). Output size is too small");
}
if (weight.defined()) {
int64_t n_input_plane = weight.size(0);
check_dim_size(input, ndim, dimf, n_input_plane);
}
if (grad_output.defined()) {
if (weight.defined()) {
int64_t n_output_plane = weight.size(1);
check_dim_size(grad_output, ndim, dimf, n_output_plane);
} else if (bias.defined()) {
int64_t n_output_plane = bias.size(0);
check_dim_size(grad_output, ndim, dimf, n_output_plane);
}
check_dim_size(grad_output, ndim, dimh, output_height);
check_dim_size(grad_output, ndim, dimw, output_width);
}
}
// Forward pass of the naive (GEMM + col2im) 2-D transposed convolution on CUDA.
//
// For every sample `elt` of the (possibly implicitly batched) input:
//   1. columns = weight^T * input[elt]   (row-major view, one cuBLAS GEMM;
//                                          weight viewed as
//                                          (C_in, C_out * kH * kW))
//   2. col2im(columns) -> output[elt]
//   3. output[elt] += bias (x) ones      (only when `bias_` is defined; adds
//                                          bias[c] to every element of
//                                          output plane c)
//
// output:   resized to (N, C_out, H_out, W_out), or to (C_out, H_out, W_out)
//           when `input_` is a 3-D (unbatched) tensor.
// columns_: scratch buffer, resized to (C_out * kH * kW, H_in * W_in).
// ones_:    scratch buffer of ones for the bias GEMM. It is only grown and
//           refilled when it is too small for H_out * W_out.
void slow_conv_transpose2d_out_cuda_template(
    Tensor& output,
    const Tensor& input_,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    const Tensor& bias_,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    Tensor& columns_,
    Tensor& ones_) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected kernel_size equals to 2, but got size ",
      kernel_size.size());
  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected dilation equals to 2, but got size ",
      dilation.size());
  TORCH_CHECK(
      padding.size() == 2,
      "It is expected padding equals to 2, but got size ",
      padding.size());
  TORCH_CHECK(
      stride.size() == 2,
      "It is expected stride equals to 2, but got size ",
      stride.size());
  // This message used to name "stride" (copy-paste slip).
  TORCH_CHECK(
      output_padding.size() == 2,
      "It is expected output_padding equals to 2, but got size ",
      output_padding.size());

  TensorArg input_arg{input_, "input", 1}, output_arg{output, "output", 2},
      weight_arg{weight_, "weight", 3}, bias_arg{bias_, "bias", 4},
      columns_arg{columns_, "columns", 5}, ones_arg{ones_, "ones", 6};
  checkAllSameGPU(
      __func__,
      {input_arg, output_arg, weight_arg, bias_arg, columns_arg, ones_arg});

  const int64_t n_output_plane = weight_.size(1);

  // Handles to the caller's scratch buffers; resize_() on them below resizes
  // the caller's tensors.
  Tensor columns = columns_;
  Tensor ones = ones_;

  const int64_t kernel_height = kernel_size[0];
  const int64_t kernel_width = kernel_size[1];
  const int64_t dilation_height = dilation[0];
  const int64_t dilation_width = dilation[1];
  const int64_t pad_height = padding[0];
  const int64_t pad_width = padding[1];
  const int64_t stride_height = stride[0];
  const int64_t stride_width = stride[1];
  const int64_t output_padding_height = output_padding[0];
  const int64_t output_padding_width = output_padding[1];

  slow_conv_transpose2d_shape_check(
      input_,
      Tensor(),
      weight_,
      bias_,
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      false);

  Tensor input = input_.contiguous();
  Tensor weight = weight_.contiguous();
  Tensor bias;
  if (bias_.defined()) {
    bias = bias_.contiguous();
    TORCH_CHECK(ones.is_contiguous(), "ones needs to be contiguous");
  }

  // Treat a 3-D (unbatched) input as a batch of one. Use a view rather than
  // resize_(): when input_ is already contiguous, contiguous() returns a
  // tensor sharing input_'s TensorImpl. resize_() would then reshape the
  // caller's tensor in place, and leave it reshaped if anything below threw.
  const bool unbatched_input = input.dim() == 3;
  if (unbatched_input) {
    input = input.unsqueeze(0);
  }

  const int64_t input_height = input.size(2);
  const int64_t input_width = input.size(3);
  const int64_t output_height = (input_height - 1) * stride_height -
      2 * pad_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  const int64_t output_width = (input_width - 1) * stride_width -
      2 * pad_width + (dilation_width * (kernel_width - 1) + 1) +
      output_padding_width;

  const int64_t batch_size = input.size(0);

  output.resize_({batch_size, n_output_plane, output_height, output_width});
  columns.resize_({n_output_plane * kernel_width * kernel_height,
                   input_height * input_width});

  // Buffer of ones for the bias GEMM. It is only ever grown and always holds
  // ones, so it can be shared across calls. It is not read without a bias,
  // so skip preparing it in that case.
  if (bias.defined() &&
      (ones.dim() != 2 ||
       ones.size(0) * ones.size(1) < output_height * output_width)) {
    ones.resize_({output_height, output_width});
    ones.fill_(1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, input.scalar_type(), "slow_conv_transpose2d_out_cuda",
      [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        // Main GEMM dimensions (loop invariant). Row-major view:
        //   columns (m x n) = weight^T (m x k) * input[elt] (k x n)
        // with m = C_out * kH * kW, n = H_in * W_in, k = C_in. cuBLAS is
        // column-major, hence the swapped operand order in the call below.
        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
        const int64_t m = weight.size(1) * weight.size(2) * weight.size(3);
        const int64_t n = columns.size(1);
        const int64_t k = weight.size(0);

        // Bias GEMM dimensions: output[elt] (m_ x n_) += bias (m_ x 1) *
        // ones (1 x n_).
        const int64_t m_ = n_output_plane;
        const int64_t n_ = output_height * output_width;
        const int64_t k_ = 1;

        for (int64_t elt = 0; elt < batch_size; ++elt) {
          Tensor input_n = input.select(0, elt);
          Tensor output_n = output.select(0, elt);

          at::cuda::blas::gemm<scalar_t>(
              'n',
              't',
              n,
              m,
              k,
              1,
              input_n.data_ptr<scalar_t>(),
              n,
              weight.data_ptr<scalar_t>(),
              m,
              0,
              columns.data_ptr<scalar_t>(),
              n);

          // Fold the columns into this sample's output planes.
          col2im<scalar_t, accscalar_t>(
              at::cuda::getCurrentCUDAStream(),
              columns.data_ptr<scalar_t>(),
              n_output_plane,
              output_height,
              output_width,
              input_height,
              input_width,
              kernel_height,
              kernel_width,
              pad_height,
              pad_width,
              stride_height,
              stride_width,
              dilation_height,
              dilation_width,
              output_n.data_ptr<scalar_t>());

          if (bias.defined()) {
            at::cuda::blas::gemm<scalar_t>(
                't',
                'n',
                n_,
                m_,
                k_,
                1,
                ones.data_ptr<scalar_t>(),
                k_,
                bias.data_ptr<scalar_t>(),
                k_,
                1,
                output_n.data_ptr<scalar_t>(),
                n_);
          }
        }
      }); // end of dispatch

  // Give an unbatched caller an unbatched output.
  if (unbatched_input) {
    output.resize_({n_output_plane, output_height, output_width});
  }
}
static void slow_conv_transpose2d_backward_out_cuda_template(
const Tensor& input_,
const Tensor& grad_output_,
Tensor& grad_input,
const Tensor& weight_,
const Tensor& grad_columns_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef dilation) {
TORCH_CHECK(
kernel_size.size() == 2,
"It is expected kernel_size equals to 2, but got size ",
kernel_size.size());
TORCH_CHECK(
dilation.size() == 2,
"It is expected dilation equals to 2, but got size ",
dilation.size());
TORCH_CHECK(
padding.size() == 2,
"It is expected padding equals to 2, but got size ",
padding.size());
TORCH_CHECK(
stride.size() == 2,
"It is expected stride equals to 2, but got size ",
stride.size());
TORCH_CHECK(
output_padding.size() == 2,
"It is expected stride equals to 2, but got size ",
output_padding.size());
TensorArg input_arg{input_, "input", 1},
grad_output_arg{grad_output_, "grad_output", 2},
weight_arg{weight_, "weight", 3},
grad_columns_arg{grad_columns_, "grad_columns", 4},
grad_input_arg{grad_input, "grad_input", 5};
checkAllSameGPU(
__func__,
{input_arg,
grad_output_arg,
weight_arg,
grad_columns_arg,
grad_input_arg});
int n_input_plane = weight_.size(0);
int n_output_plane = weight_.size(1);
int64_t kernel_height = kernel_size[0];
int64_t kernel_width = kernel_size[1];
int64_t dilation_height = dilation[0];
int64_t dilation_width = dilation[1];
int64_t pad_height = padding[0];
int64_t pad_width = padding[1];
int64_t stride_height = stride[0];
int64_t stride_width = stride[1];
int64_t output_padding_height = output_padding[0];
int64_t output_padding_width = output_padding[1];
Tensor grad_columns = grad_columns_;
slow_conv_transpose2d_shape_check(
input_,
grad_output_,
weight_,
Tensor(),
kernel_height,
kernel_width,
stride_height,
stride_width,
pad_height,
pad_width,
output_padding_height,
output_padding_width,
dilation_height,
dilation_width,
false);
Tensor input = input_.contiguous();
Tensor grad_output = grad_output_.contiguous();
Tensor weight = weight_.contiguous();
bool is_batch = false;
if (input.dim() == 3) {
// Force batch
is_batch = true;
input.resize_({1, input.size(0), input.size(1), input.size(2)});
grad_output.resize_(
{1, grad_output.size(0), grad_output.size(1), grad_output.size(2)});
}
int64_t input_width = input.size(3);
int64_t input_height = input.size(2);
int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
(dilation_height * (kernel_height - 1) + 1) + output_padding_height;
int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
(dilation_width * (kernel_width - 1) + 1) + output_padding_width;
// Batch size + input planes
int64_t batch_size = input.size(0);
// Resize output
grad_input.resize_({batch_size, n_input_plane, input_height, input_width});
// Resize temporary columns
grad_columns.resize_({n_output_plane * kernel_width * kernel_height,
input_height * input_width});
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
grad_output.scalar_type(), "slow_conv_transpose2d_backward_out_cuda", [&] {
// Helpers
Tensor grad_input_n = Tensor();
Tensor grad_output_n = Tensor();
// For each elt in batch, do:
for (int elt = 0; elt < batch_size; elt++) {
// Matrix mulitply per sample:
grad_input_n = grad_input.select(0, elt);
grad_output_n = grad_output.select(0, elt);
if (kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
stride_width != 1 || pad_height != 0 || pad_width != 0 ||
dilation_height != 1 || dilation_width != 1) {
im2col<scalar_t>(
at::cuda::getCurrentCUDAStream(),
grad_output_n.data_ptr<scalar_t>(),
n_output_plane,
output_height,
output_width,
input_height,
input_width,
kernel_height,
kernel_width,
pad_height,
pad_width,
stride_height,
stride_width,
dilation_height,
dilation_width,
grad_columns.data_ptr<scalar_t>());
}
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
int64_t m = weight.size(0);
int64_t n = grad_columns.size(1);
int64_t k = weight.size(1) * weight.size(2) * weight.size(3);
// Do GEMM (note: this is a bit confusing because gemm assumes
// column-major matrices)
auto gemm_in_ptr =
(kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
stride_width != 1 || pad_height != 0 || pad_width != 0 ||
dilation_height != 1 || dilation_width != 1)
? grad_columns.data_ptr<scalar_t>()
: grad_output_n.data_ptr<scalar_t>();
at::cuda::blas::gemm<scalar_t>(
'n',
'n',
n,
m,
k,
1,
gemm_in_ptr,
n,
weight.data_ptr<scalar_t>(),
k,
0,
grad_input_n.data_ptr<scalar_t>(),
n);
}
// Resize output
if (is_batch) {
grad_output.resize_({n_output_plane, output_height, output_width});
input.resize_({n_input_plane, input_height, input_width});
grad_input.resize_({n_input_plane, input_height, input_width});
}
}); // end of dispatch
}
void slow_conv_transpose2d_acc_grad_parameters_cuda_template(
const Tensor& input_,
const Tensor& grad_output_,
Tensor& grad_weight,
Tensor& grad_bias,
const Tensor& columns_,
const Tensor& ones_,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef dilation,
int scale_) {
TORCH_CHECK(
kernel_size.size() == 2,
"It is expected kernel_size equals to 2, but got size ",
kernel_size.size());
TORCH_CHECK(
dilation.size() == 2,
"It is expected dilation equals to 2, but got size ",
dilation.size());
TORCH_CHECK(
padding.size() == 2,
"It is expected padding equals to 2, but got size ",
padding.size());
TORCH_CHECK(
stride.size() == 2,
"It is expected stride equals to 2, but got size ",
stride.size());
TORCH_CHECK(
output_padding.size() == 2,
"It is expected stride equals to 2, but got size ",
output_padding.size());
TensorArg input_arg{input_, "input", 1},
grad_output_arg{grad_output_, "grad_output", 2},
grad_weight_arg{grad_weight, "grad_weight", 3},
grad_bias_arg{grad_bias, "grad_bias", 4},
columns_arg{columns_, "columns", 5}, ones_arg{ones_, "ones", 6};
checkAllSameGPU(
__func__,
{input_arg,
grad_output_arg,
grad_weight_arg,
grad_bias_arg,
columns_arg,
ones_arg});
int64_t kernel_height = kernel_size[0];
int64_t kernel_width = kernel_size[1];
int64_t dilation_height = dilation[0];
int64_t dilation_width = dilation[1];
int64_t pad_height = padding[0];
int64_t pad_width = padding[1];
int64_t stride_height = stride[0];
int64_t stride_width = stride[1];
int64_t output_padding_height = output_padding[0];
int64_t output_padding_width = output_padding[1];
Tensor columns = columns_;
Tensor ones = ones_;
slow_conv_transpose2d_shape_check(
input_,
grad_output_,
grad_weight,
grad_bias,
kernel_height,
kernel_width,
stride_height,
stride_width,
pad_height,
pad_width,
output_padding_height,
output_padding_width,
dilation_height,
dilation_width,
true);
Tensor input = input_.contiguous();
Tensor grad_output = grad_output_.contiguous();
int64_t n_output_plane;
if (grad_weight.defined()) {
n_output_plane = grad_weight.size(1);
} else if (grad_bias.defined()) {
n_output_plane = grad_bias.size(0);
} else {
return;
}
if (grad_weight.defined()) {
TORCH_CHECK(
grad_weight.is_contiguous(), "grad_weight needs to be contiguous");
}
TORCH_CHECK(columns.is_contiguous(), "columns needs to be contiguous");
if (grad_bias.defined()) {
TORCH_CHECK(grad_bias.is_contiguous(), "grad_bias needs to be contiguous");
TORCH_CHECK(ones.is_contiguous(), "ones needs to be contiguous");
}
bool is_batch = false;
if (input.dim() == 3) {
// Force batch
is_batch = true;
input.resize_({1, input.size(0), input.size(1), input.size(2)});
grad_output.resize_(
{1, grad_output.size(0), grad_output.size(1), grad_output.size(2)});
}
int64_t input_width = input.size(3);
int64_t input_height = input.size(2);
int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
(dilation_height * (kernel_height - 1) + 1) + output_padding_height;
int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
(dilation_width * (kernel_width - 1) + 1) + output_padding_width;
// Batch size + input planes
int64_t batch_size = input.size(0);
// Define a buffer of ones, for bias accumulation
if (ones.dim() != 2 ||
ones.size(0) * ones.size(1) < output_height * output_width) {
// Resize plane and fill with ones...
ones.resize_({output_height, output_width});
ones.fill_(1); // or static_cast<scalar_t>(1)
}
// Resize temporary columns
columns.resize_({n_output_plane * kernel_width * kernel_height,
input_height * input_width});
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_transpose2d_acc_grad_parameters_cuda", [&] {
// Helpers
Tensor input_n = Tensor();
Tensor grad_output_n = Tensor();
scalar_t scale = static_cast<scalar_t>(scale_);
// For each elt in batch, do:
for (int elt = 0; elt < batch_size; elt++) {
// Matrix mulitply per output:
grad_output_n = grad_output.select(0, elt);
// Do Weight:
if (grad_weight.defined()) {
// Matrix mulitply per output:
input_n = input.select(0, elt);
if (kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
stride_width != 1 || pad_height != 0 || pad_width != 0 ||
dilation_height != 1 || dilation_width != 1) {
// Extract columns:
im2col<scalar_t>(
at::cuda::getCurrentCUDAStream(),
grad_output_n.data_ptr<scalar_t>(),
n_output_plane,
output_height,
output_width,
input_height,
input_width,
kernel_height,
kernel_width,
pad_height,
pad_width,
stride_height,
stride_width,
dilation_height,
dilation_width,
columns.data_ptr<scalar_t>());
}
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
int64_t n = columns.size(0); // n_output_plane * kh * kw
int64_t m = input_n.size(0); // n_input_plane
int64_t k = columns.size(1); // input_height * input_width
// Do GEMM (note: this is a bit confusing because gemm assumes
// column-major matrices)
auto gemm_in_ptr =
(kernel_height != 1 || kernel_width != 1 ||
stride_height != 1 || stride_width != 1 || pad_height != 0 ||
pad_width != 0 || dilation_height != 1 || dilation_width != 1)
? columns.data_ptr<scalar_t>()
: grad_output_n.data_ptr<scalar_t>();
at::cuda::blas::gemm<scalar_t>(
't',
'n',
n,
m,
k,
scale,
gemm_in_ptr,
k,
input_n.data_ptr<scalar_t>(),
k,
1,
grad_weight.data_ptr<scalar_t>(),
n);
}
// Do Bias:
if (grad_bias.defined()) {
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
int64_t m_ = n_output_plane;
int64_t k_ = output_height * output_width;
// Do GEMV (note: this is a bit confusing because gemv assumes
// column-major matrices)
at::cuda::blas::gemv<scalar_t>(
't',
k_,
m_,
scale,
grad_output_n.data_ptr<scalar_t>(),
k_,
ones.data_ptr<scalar_t>(),
1,
1,
grad_bias.data_ptr<scalar_t>(),
1);
}
}
// Resize
if (is_batch) {
grad_output.resize_({n_output_plane, output_height, output_width});
input.resize_({input.size(1), input_height, input_width});
}
}); // end of dispatch
}
} // namespace
// out= entry point of the naive CUDA 2-D transposed convolution.
// Writes the result into `output` (resized by the template) and returns it.
// The im2col column buffer and the bias "ones" buffer are allocated per call,
// as contiguous tensors shaped like `input`; the template resizes them.
Tensor& slow_conv_transpose2d_out_cuda(const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  const c10::MaybeOwned<Tensor> maybe_bias =
      at::borrow_from_optional_tensor(bias_opt);

  Tensor columns_buffer =
      at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor ones_buffer = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  slow_conv_transpose2d_out_cuda_template(
      output, input, weight, kernel_size, *maybe_bias, stride, padding,
      output_padding, dilation, columns_buffer, ones_buffer);
  return output;
}
// Functional entry point of the naive CUDA 2-D transposed convolution.
// Allocates the result together with per-call scratch buffers, all as
// contiguous tensors shaped like `input` that the template resizes, then runs
// the forward template and returns the result.
Tensor slow_conv_transpose2d_cuda(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  const c10::MaybeOwned<Tensor> maybe_bias =
      at::borrow_from_optional_tensor(bias_opt);

  Tensor result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor columns_buffer =
      at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor ones_buffer = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  slow_conv_transpose2d_out_cuda_template(
      result, input, weight, kernel_size, *maybe_bias, stride, padding,
      output_padding, dilation, columns_buffer, ones_buffer);
  return result;
}
std::tuple<Tensor&, Tensor&, Tensor&> slow_conv_transpose2d_backward_out_cuda(const Tensor& grad_output,
const Tensor& input,
const Tensor& weight,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef dilation,
const Tensor& columns,
const Tensor& ones,
Tensor& grad_input,
Tensor& grad_weight,
Tensor& grad_bias) {
if (grad_input.defined()) {
slow_conv_transpose2d_backward_out_cuda_template(
input,
grad_output,
grad_input,
weight,
columns,
kernel_size,
stride,
padding,
output_padding,
dilation);
}
if (grad_weight.defined()) {
grad_weight.resize_(weight.sizes());
grad_weight.zero_();
}
if (grad_bias.defined()) {
grad_bias.resize_({weight.size(1)});
grad_bias.zero_();
}
if (grad_weight.defined() || grad_bias.defined()) {
slow_conv_transpose2d_acc_grad_parameters_cuda_template(
input,
grad_output,
grad_weight,
grad_bias,
columns,
ones,
kernel_size,
stride,
padding,
output_padding,
dilation,
1);
}
return std::tuple<Tensor&, Tensor&, Tensor&>(
grad_input, grad_weight, grad_bias);
}
std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose2d_backward_cuda(
const Tensor& grad_output,
const Tensor& input,
const Tensor& weight,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef output_padding,
IntArrayRef dilation,
const Tensor& columns,
const Tensor& ones,
std::array<bool, 3> output_mask) {
Tensor grad_input;
Tensor grad_weight;
Tensor grad_bias;
if (output_mask[0]) {
grad_input = at::empty({0}, grad_output.options());
} else {
grad_input = Tensor();
}
if (output_mask[1]) {
grad_weight = at::empty({0}, grad_output.options());
} else {
grad_weight = Tensor();
}
if (output_mask[2]) {
grad_bias = at::empty({0}, grad_output.options());
} else {
grad_bias = Tensor();
}
if (grad_input.defined()) {
slow_conv_transpose2d_backward_out_cuda_template(
input,
grad_output,
grad_input,
weight,
columns,
kernel_size,
stride,
padding,
output_padding,
dilation);
}
if (grad_weight.defined()) {
grad_weight.resize_(weight.sizes());
grad_weight.zero_();
}
if (grad_bias.defined()) {
grad_bias.resize_({weight.size(1)});
grad_bias.zero_();
}
if (grad_weight.defined() || grad_bias.defined()) {
slow_conv_transpose2d_acc_grad_parameters_cuda_template(
input,
grad_output,
grad_weight,
grad_bias,
columns,
ones,
kernel_size,
stride,
padding,
output_padding,
dilation,
1);
}
return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}
} // namespace native
} // namespace at