forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LegacyVmapTransforms.cpp
296 lines (264 loc) · 11.1 KB
/
LegacyVmapTransforms.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
#include <ATen/LegacyVmapTransforms.h>
#include <ATen/ATen.h>
#include <ATen/core/IListRef.h>
#include <c10/util/irange.h>
namespace at {
// Returns true iff, for every position i, the i-th entry of `bdims` refers to
// physical dimension i -- that is, the batch dims already occupy the leading
// dimensions of the tensor in the order they are listed.
// An empty `bdims` trivially yields true.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
  const auto num_bdims = static_cast<int64_t>(bdims.size());
  for (int64_t pos = 0; pos < num_bdims; ++pos) {
    if (bdims[pos].dim() != pos) {
      return false;
    }
  }
  return true;
}
// Takes a BatchedTensorImpl, permutes all of the batch dims to the front,
// and then returns a physical version of the Tensor.
//
// In the result, the batch dims come first in the order given by
// `batched->bdims()`, followed by every remaining dim in its original
// relative order. If the batch dims are already at the front in order, the
// underlying tensor is returned as-is without calling permute().
static Tensor permuteBatchDimsToFront(BatchedTensorImpl* batched) {
  auto bdims = batched->bdims();
  const Tensor& physical_tensor = batched->value();
  if (areBdimsAtFrontInOrder(bdims)) {
    return physical_tensor;
  }
  const auto sizes = physical_tensor.sizes();
  // Constructed at its final size up front; every slot is overwritten below.
  VmapDimVector permutation(sizes.size(), 0);
  const auto is_bdim = createBatchDimBitset(bdims);
  int64_t idx = 0;
  // First the batch dims, in the order they appear in `bdims` ...
  for (const auto& bdim : bdims) {
    permutation[idx++] = bdim.dim();
  }
  // ... then every physical dim that is not flagged as a batch dim.
  for (const auto ptr : c10::irange(sizes.size())) {
    if (is_bdim[ptr]) {
      continue;
    }
    permutation[idx++] = ptr;
  }
  return physical_tensor.permute(permutation);
}
// Builds the physical view of a single BatchedTensor: the underlying tensor
// with its batch dims permuted to the front, paired with the bitset of vmap
// levels taken from its batch dims.
// Asserts (internal assert) if `logical_tensor` is not a BatchedTensor.
VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logical_tensor) {
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  TORCH_INTERNAL_ASSERT(
      batched,
      "logicalToPhysical(tensor) should only be passed a BatchedTensor");
  auto physical = permuteBatchDimsToFront(batched);
  const auto levels = createVmapLevelsBitset(batched->bdims());
  return VmapPhysicalView(std::move(physical), levels);
}
// Returns the number of bits set in `levels_`, i.e. how many vmap levels
// (and hence leading batch dims) this physical view carries.
int64_t VmapPhysicalView::numBatchDims() const {
  return static_cast<int64_t>(levels_.count());
}
// Returns the number of non-batch dims: the physical tensor's rank minus the
// number of batch dims.
int64_t VmapPhysicalView::numLogicalDims() const {
  const int64_t physical_ndim = tensor_.dim();
  return physical_ndim - numBatchDims();
}
// Maps logical dim indices to physical dim indices by wrapping each one
// against the logical rank and offsetting it by the number of batch dims.
//
// If `opt_logical_dims` is absent or empty, returns the physical indices of
// ALL logical dims, in order.
VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
  const auto logical_ndim = numLogicalDims();
  // Loop-invariant; computed once instead of once per pushed element.
  const auto num_batch_dims = numBatchDims();
  // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
  VmapDimVector result;
  if (opt_logical_dims.has_value() && !opt_logical_dims.value().empty()) {
    const auto logical_dims = opt_logical_dims.value();
    // Reserve for the number of dims actually requested.
    result.reserve(logical_dims.size());
    for (const auto dim : logical_dims) {
      result.push_back(maybe_wrap_dim(dim, logical_ndim) + num_batch_dims);
    }
  } else {
    result.reserve(logical_ndim);
    for (int64_t dim = 0; dim < logical_ndim; dim++) {
      result.push_back(dim + num_batch_dims);
    }
  }
  return result;
}
// Maps a single logical dim index to its physical index: wraps `logical_dim`
// against the logical rank, then shifts it past the leading batch dims.
int64_t VmapPhysicalView::getPhysicalDim(int64_t logical_dim) const {
  const auto wrapped_dim = maybe_wrap_dim(logical_dim, numLogicalDims());
  return wrapped_dim + numBatchDims();
}
// Returns the physical shape corresponding to `logical_shape`: the sizes of
// this view's leading batch dims followed by the entries of `logical_shape`.
VmapDimVector VmapPhysicalView::getPhysicalShape(IntArrayRef logical_shape) const {
  const auto num_bdims = numBatchDims();
  const auto physical_sizes = tensor_.sizes();

  VmapDimVector physical_shape;
  physical_shape.reserve(logical_shape.size() + num_bdims);
  // Batch-dim sizes come from the physical tensor ...
  physical_shape.insert(
      physical_shape.end(),
      physical_sizes.begin(),
      physical_sizes.begin() + num_bdims);
  // ... and the rest is the requested logical shape, verbatim.
  physical_shape.insert(
      physical_shape.end(), logical_shape.begin(), logical_shape.end());
  return physical_shape;
}
// Builds a BatchDims list from a bitset of levels: visiting levels in
// increasing order, each level that is set is paired with the next dim index,
// starting at 0. This describes batch dims laid out at the front of a tensor.
static BatchDims computeFrontBatchDimsFromLevels(std::bitset<kVmapNumLevels> levels_bitset) {
  BatchDims front_bdims;
  int64_t next_dim = 0;
  for (const auto level : c10::irange(kVmapNumLevels)) {
    if (levels_bitset[level]) {
      front_bdims.emplace_back(level, next_dim);
      ++next_dim;
    }
  }
  return front_bdims;
}
// Accepts either a regular Tensor or a BatchedTensor and returns a pair of:
//   - the underlying physical tensor, with any vmapped dimensions permuted to
//     the front, and
//   - the bitset of vmap levels present in the tensor.
// For a tensor that is not batched, this is the tensor itself together with
// an empty bitset.
static std::pair<Tensor,std::bitset<kVmapNumLevels>>
getPhysicalTensorAndLevels(const Tensor& self) {
  auto* batched = maybeGetBatchedImpl(self);
  if (batched == nullptr) {
    return {self, std::bitset<kVmapNumLevels>()};
  }
  return {permuteBatchDimsToFront(batched), createVmapLevelsBitset(batched->bdims())};
}
// Given a Tensor or a BatchedTensor, produces a physical view of it that has
// one batch dimension per level in `requested_levels` (size 1 for levels the
// tensor does not have) followed by exactly `requested_example_dim`
// non-batch dimensions (left-padded with size-1 dims as needed).
//
// This is the preparation step for broadcasting operations. For example, to
// add two BatchedTensors of sizes [B0, 3] and [B0, B1, 2, 3], where the Bi are
// batch dimensions, the batch dimensions and the non-batch dimensions (the
// "example" dimensions) have to be aligned separately, giving tensors of size
// [B0, 1, 1, 3] and [B0, B1, 2, 3] that can then be added.
//
// Concretely, for those two tensors:
//
// 1) alignBatchDimsAtFront([B0, 3], requested_levels={0, 1}, requested_example_dim=2)
//    returns a physical view of size [B0, 1, 1, 3]: one extra dimension for
//    level 1 and one extra dimension to pad the example dimensions to 2.
//
// 2) alignBatchDimsAtFront([B0, B1, 2, 3], requested_levels={0, 1}, requested_example_dim=2)
//    returns a physical view of size [B0, B1, 2, 3].
//
// Internal-asserts that `requested_levels` is a superset of the tensor's own
// levels and that the tensor has at most `requested_example_dim` example dims.
static Tensor alignBatchDimsAtFront(
    const Tensor& self,
    std::bitset<kVmapNumLevels> requested_levels,
    int64_t requested_example_dim) {
  auto [physical_tensor, tensor_levels] = getPhysicalTensorAndLevels(self);
  TORCH_INTERNAL_ASSERT(
      (tensor_levels | requested_levels) == requested_levels,
      "`requested_levels` must be a superset of `self`'s levels");

  const auto physical_sizes = physical_tensor.sizes();
  const auto tensor_num_bdims = static_cast<int64_t>(tensor_levels.count());
  const auto tensor_example_dim =
      static_cast<int64_t>(physical_sizes.size()) - tensor_num_bdims;
  TORCH_INTERNAL_ASSERT(tensor_example_dim <= requested_example_dim);

  // Fast path: the physical tensor already has the requested layout, so no
  // additional view is required.
  const bool same_levels = (tensor_levels == requested_levels);
  if (same_levels && tensor_example_dim == requested_example_dim) {
    return physical_tensor;
  }

  // Start from all ones; only dims the tensor really has get overwritten.
  VmapDimVector aligned_sizes(requested_levels.count() + requested_example_dim, 1);

  // Example (non-batch) dims are right-aligned:
  // aligned_sizes[-tensor_example_dim:] = physical_sizes[-tensor_example_dim:]
  std::copy(
      physical_sizes.rbegin(),
      physical_sizes.rbegin() + tensor_example_dim,
      aligned_sizes.rbegin());

  // Batch dims: walk the requested levels in increasing order. Each requested
  // level owns the next output slot; if the tensor has that level, its size
  // is the next of the tensor's (front-ordered) batch dims.
  int64_t out_dim = 0;
  int64_t src_dim = 0;
  for (const auto level : c10::irange(kVmapNumLevels)) {
    if (!requested_levels[level]) {
      continue;
    }
    if (tensor_levels[level]) {
      aligned_sizes[out_dim] = physical_sizes[src_dim];
      ++src_dim;
    }
    ++out_dim;
  }
  return physical_tensor.view(aligned_sizes);
}
// Converts a list of logical tensors into physical views that all share the
// same leading batch dimensions and the same set of vmap levels.
//
// The algorithm is as follows:
// 1. Compute the union of the vmap levels present across `logical_tensors`
//    (the "collective levels").
// 2. Move all batch dims to the front of each tensor and add extra dims
//    of size 1, so that every tensor has a dimension for each of the
//    collective levels.
// 3. Compute the batch_sizes.
// 4. Expand each physical tensor so that its batch sizes equal `batch_sizes`.
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
  // Figure out all of the collective vmap levels in `logical_tensors`.
  std::bitset<kVmapNumLevels> collective_levels;
  for (const auto& logical_tensor : logical_tensors) {
    auto* batched = maybeGetBatchedImpl(logical_tensor);
    if (batched) {
      collective_levels |= createVmapLevelsBitset(batched->bdims());
    }
  }

  // Populate physical_tensors.
  // This contains a list of regular (non-Batched) Tensors where all of the
  // batch dims have been moved to the front of the tensor. Any previously
  // non-existing batch dims get added to the tensors as new dimensions of size 1.
  std::vector<Tensor> physical_tensors;
  // One entry per input; reserve up front to avoid regrowth while pushing.
  physical_tensors.reserve(logical_tensors.size());
  int64_t num_batch_dims = collective_levels.count();
  for (const auto& logical_tensor : logical_tensors) {
    // Keep each tensor's own logical rank; only the batch dims are aligned.
    auto requested_example_dim = /*logical_dim*/logical_tensor.dim();
    auto physical_tensor = alignBatchDimsAtFront(
        logical_tensor, collective_levels, requested_example_dim);
    physical_tensors.push_back(std::move(physical_tensor));
  }

  // Compute batch_sizes: for each batch dim, take any size that is not 1
  // (size-1 entries are the padding added above and never overwrite).
  VmapDimVector batch_sizes(num_batch_dims, 1);
  for (const auto& physical_tensor : physical_tensors) {
    auto physical_sizes = physical_tensor.sizes();
    for (const auto dim : c10::irange(num_batch_dims)) {
      if (physical_sizes[dim] != 1) {
        batch_sizes[dim] = physical_sizes[dim];
      }
    }
  }

  // Expand each physical_tensor so that it has batch sizes `batch_sizes`
  VmapPhysicalViewVec result;
  result.reserve(physical_tensors.size());
  for (const auto& physical_tensor : physical_tensors) {
    // Target shape = shared batch sizes + this tensor's own example sizes.
    VmapDimVector expanded_size(batch_sizes.begin(), batch_sizes.end());
    auto physical_sizes = physical_tensor.sizes();
    expanded_size.insert(
        expanded_size.end(),
        physical_sizes.begin() + num_batch_dims,
        physical_sizes.end());
    result.emplace_back(physical_tensor.expand(expanded_size), collective_levels);
  }
  return result;
}
// Scans `logical_tensors` and returns a pair of:
//   - the union of the vmap levels of every BatchedTensor in the list, and
//   - the largest logical rank (`tensor.dim()`) among all tensors.
// Internal-asserts that the list is non-empty.
static std::pair<std::bitset<kVmapNumLevels>,int64_t>
getLevelsAndLargestLogicalDim(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(!logical_tensors.empty());
  std::bitset<kVmapNumLevels> union_levels;
  int64_t max_logical_dim = -1;
  for (const auto& tensor : logical_tensors) {
    if (auto* batched = maybeGetBatchedImpl(tensor)) {
      union_levels |= createVmapLevelsBitset(batched->bdims());
    }
    max_logical_dim = std::max<int64_t>(max_logical_dim, /*logical dim*/tensor.dim());
  }
  return std::make_pair(union_levels, max_logical_dim);
}
// Converts exactly two logical tensors into physical views that can be fed to
// a broadcasting operation: both are aligned to the union of their vmap levels
// and to the larger of their logical ranks.
// Internal-asserts that exactly two tensors are passed.
VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) {
  TORCH_INTERNAL_ASSERT(
      logical_tensors.size() == 2,
      "This function has only been tested for two tensors. Please add more tests ",
      "before removing this check ");

  auto [levels, largest_logical_dim] = getLevelsAndLargestLogicalDim(logical_tensors);

  VmapPhysicalViewVec physical_views;
  for (const auto& tensor : logical_tensors) {
    // NB: Aligning `tensor` is not always strictly necessary.
    // For example, when adding two tensors of size (B, 2), and (3, 2), where
    // the first Tensor is a BatchedTensor with batch dim B and the second is
    // a regular Tensor, we will return views of size (B, 1, 2) and (1, 3, 2).
    // The view on the second tensor is redundant, because broadcasting
    // semantics already allow adding tensors of size (B, 1, 2) and (3, 2).
    //
    // If this unnecessary view is a problem, consider optimizing it away in
    // the future. This may involve creating a new type of VmapPhysicalView
    physical_views.emplace_back(
        alignBatchDimsAtFront(tensor, levels, largest_logical_dim), levels);
  }
  return physical_views;
}
// Returns a VmapPhysicalToLogicalMap constructed from this view's levels.
VmapPhysicalToLogicalMap VmapPhysicalView::getPhysicalToLogicalMap() const {
  VmapPhysicalToLogicalMap physical_to_logical(levels_);
  return physical_to_logical;
}
// Wraps `physical_tensor` via makeBatched(), using batch dims derived from
// `levels_` with the batch dims placed at the front (dim 0, 1, ... in
// increasing level order).
Tensor VmapPhysicalToLogicalMap::apply(const Tensor& physical_tensor) const {
  auto front_bdims = computeFrontBatchDimsFromLevels(levels_);
  return makeBatched(physical_tensor, std::move(front_bdims));
}
void VmapPhysicalToLogicalMap::applyInplace(std::vector<Tensor>& physical_tensors) const {
for (auto & physical_tensor : physical_tensors) {
physical_tensor = apply(physical_tensor);
}
}
} // namespace at