forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TensorIteratorInternal.h
64 lines (54 loc) · 1.82 KB
/
TensorIteratorInternal.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#pragma once
#include <ATen/native/TensorIterator.h>
#include <c10/util/SmallBuffer.h>
namespace at {
// Iteration state that serial_for_each (later in this header) uses to walk a
// Range over a multi-dimensional shape, one 2-D chunk at a time.
// Only declarations are visible here; the member functions are defined
// elsewhere, so notes on their exact semantics are hedged.
struct DimCounter {
// Builds a counter from an iteration `shape` and the `range` to cover.
// NOTE(review): presumably positions the counter at range.begin — confirm
// against the out-of-line definition.
DimCounter(IntArrayRef shape, Range range);
// Advances the counter by `step`. serial_for_each passes back the value it
// obtained from max_2d_step() once the loop callback has run for that chunk.
void increment(const std::array<int64_t, 2>& step);
// Termination test: serial_for_each keeps iterating until this returns true.
bool is_done() const;
// Returns two sizes; serial_for_each forwards element 0 and element 1 as the
// two size arguments of the 2-D loop callback.
std::array<int64_t, 2> max_2d_step() const;
// Shape this counter iterates over.
// NOTE(review): IntArrayRef looks like a non-owning view, so the underlying
// storage presumably has to outlive the counter — confirm.
IntArrayRef shape;
// Range this counter was constructed with.
Range range;
// Per-dimension counter values; serial_for_each hands these to get_data_ptrs
// as the `counter` argument to compute the current data pointers.
c10::SmallBuffer<int64_t, 4> values;
// NOTE(review): presumably the current linear position within `range` —
// not derivable from this header; confirm against the definition.
int64_t offset;
};
namespace internal {
// Computes the data pointer of every operand at a given multi-dimensional
// position.
//
//   ptrs    - output; receives base.size() pointers.
//   base    - one base pointer per operand.
//   strides - flat table indexed as [dim * base.size() + arg]; each entry is
//             multiplied by the counter value and added to a char*, i.e. it
//             is applied as a byte offset.
//   counter - one index per dimension; its length is the number of dimensions
//             that are applied.
inline void get_data_ptrs(
    char** ptrs, ArrayRef<char*> base, IntArrayRef strides, IntArrayRef counter) {
  const int64_t num_args = base.size();

  // Start every operand at its base address.
  std::copy(base.begin(), base.end(), ptrs);

  // `row_start` tracks dim * num_args, the first stride entry of the
  // dimension currently being applied.
  int64_t row_start = 0;
  for (const int64_t index : counter) {
    for (int64_t arg = 0; arg < num_args; ++arg) {
      ptrs[arg] += index * strides[row_start + arg];
    }
    row_start += num_args;
  }
}
// Invokes `loop` over the part of the iteration space selected by `range`,
// one chunk after another.
//
//   shape     - extents of the iteration space; its length picks the code path.
//   strides   - forwarded unchanged to `loop` and used by get_data_ptrs to
//               offset the base pointers. Debug builds assert that it holds
//               ntensors * max(2, ndim) entries.
//   base_ptrs - `ntensors` base data pointers, one per operand.
//   ntensors  - number of operands.
//   loop      - 2-D loop callback, called as loop(ptrs, strides, size0, size1).
//   range     - sub-range to process.
inline void serial_for_each(
    IntArrayRef shape, IntArrayRef strides,
    char** base_ptrs, size_t ntensors,
    typename TensorIteratorBase::loop2d_t loop, Range range) {
  const auto ndim = shape.size();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      strides.size() == ntensors * std::max(size_t{2}, ndim));

  if (ndim > 1) {
    // Multi-dimensional case: let DimCounter carve the range into 2-D chunks
    // and recompute the operand pointers before each callback invocation.
    c10::SmallBuffer<char*, 4> chunk_ptrs(ntensors);
    for (DimCounter position(shape, range); !position.is_done();) {
      get_data_ptrs(
          chunk_ptrs.data(), {base_ptrs, ntensors}, strides, position.values);
      const auto extent = position.max_2d_step();
      loop(chunk_ptrs.data(), strides.data(), extent[0], extent[1]);
      position.increment(extent);
    }
    return;
  }

  // At most one dimension: the whole range is a single chunk of outer size 1.
  if (range.begin != 0) {
    // The range does not start at the origin, so shift the base pointers by
    // range.begin along the only dimension first.
    c10::SmallBuffer<char*, 4> shifted_ptrs(ntensors);
    get_data_ptrs(
        shifted_ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin});
    loop(shifted_ptrs.data(), strides.data(), range.size(), 1);
    return;
  }

  // Range starts at the origin: the base pointers can be used directly.
  loop(base_ptrs, strides.data(), range.size(), 1);
}
}} // namespace at::internal