forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
operator_gradient.h
337 lines (301 loc) · 9.98 KB
/
operator_gradient.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
#ifndef CAFFE2_CORE_OPERATOR_GRADIENT_H_
#define CAFFE2_CORE_OPERATOR_GRADIENT_H_
#include "c10/util/Registry.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
/**
 * @brief Names of the gradient blob(s) associated with one blob, covering
 * both dense and sparse gradients.
 *
 * A dense gradient is recorded as a single blob name in dense_. A sparse
 * gradient is recorded as two blob names: indices_ for the slice indices and
 * values_ for the slice values. An empty string means "not set".
 */
struct TORCH_API GradientWrapper {
  string dense_;
  string indices_;
  string values_;

  /// True when a dense gradient name has been recorded.
  inline bool IsDense() const {
    return !dense_.empty();
  }
  /// True when at least one of the sparse component names has been recorded.
  inline bool IsSparse() const {
    return !(indices_.empty() && values_.empty());
  }
  /// True when neither a dense nor any sparse gradient name is recorded.
  inline bool IsEmpty() const {
    return !IsSparse() && !IsDense();
  }
};
/**
* A struct that holds the gradient operators and related gradient maps.
*/
struct TORCH_API GradientOpsMeta {
vector<OperatorDef> ops_;
vector<GradientWrapper> g_input_;
GradientOpsMeta() {}
GradientOpsMeta(
const vector<OperatorDef>& ops,
const vector<GradientWrapper>& v)
: ops_(ops), g_input_(v) {}
};
class TORCH_API GradientMakerBase {
public:
GradientMakerBase(
const OperatorDef& def,
const vector<GradientWrapper>& g_output)
: def_(def), g_output_(g_output), g_input_(def.input_size()){};
virtual ~GradientMakerBase() {}
virtual bool CopyDeviceOption() const {
return true;
}
virtual bool CopyEngine() const {
return true;
}
virtual bool CopyArguments() const {
return true;
}
virtual void VerifyOp() const {
auto* schema = OpSchemaRegistry::Schema(def_.type());
if (schema) {
CAFFE_ENFORCE(
schema->Verify(def_),
"(GradientMaker) Operator def did not pass schema checking: ",
ProtoDebugString(def_));
}
}
/**
* @brief Returns the gradient ops meta.
*
* If your gradient op generator only use standard input and output
* manipulations, you can simply implement GetGradientDefs() that
* returns vector<OperatorDef>. In that, you can call GI, GI_V and GI_I
* that will automatically create the gradient registration for you.
*
* If you need to do custom gradient name registration, overload this
* function directly.
*/
virtual GradientOpsMeta Get() {
VerifyOp();
vector<OperatorDef> new_defs = GetGradientDefs();
for (auto& opdef : new_defs) {
opdef.set_is_gradient_op(true);
}
return GradientOpsMeta(new_defs, g_input_);
};
const OperatorDef& Def() const {
return def_;
}
protected:
virtual vector<OperatorDef> GetGradientDefs() {
CAFFE_NOT_IMPLEMENTED;
}
// Helper functions to return names for the gradient computation.
// I(idx), O(idx): return the input and output names.
// GO(idx): return the name of the gradient for output idx.
// GI(idx), GI_I(idx), GI_V(idx): return the name of the gradient for
// input idx, and also registers that name into the gradient
// registry to be returned.
string I(const int i) {
CAFFE_ENFORCE((i >= 0) && (i < def_.input().size()));
return def_.input(i);
}
string O(const int i) {
CAFFE_ENFORCE((i >= 0) && (i < def_.output().size()));
return def_.output(i);
}
string GI(const int i) {
CAFFE_ENFORCE(
!g_input_.at(i).IsSparse(),
"Input ",
def_.input(i),
" already set to sparse.");
g_input_.at(i).dense_ = GradientName(def_.input(i));
return GradientName(def_.input(i));
}
string GI_I(const int i) {
CAFFE_ENFORCE(
!g_input_.at(i).IsDense(),
"Input ",
def_.input(i),
" already set to dense.");
g_input_.at(i).indices_ = GradientSliceIndices(def_.input(i));
return GradientSliceIndices(def_.input(i));
}
string GI_V(const int i) {
CAFFE_ENFORCE(
!g_input_.at(i).IsDense(),
"Input ",
def_.input(i),
" already set to dense.");
g_input_.at(i).values_ = GradientSliceValues(def_.input(i));
return GradientSliceValues(def_.input(i));
}
string GO(const int i) {
CAFFE_ENFORCE(
g_output_.at(i).IsDense(),
"Gradient of output ",
def_.output(i),
(g_output_.at(i).IsSparse() ? " is sparse (expected dense)."
: " is not provided!"));
return g_output_.at(i).dense_;
}
string GO_I(const int i) {
CAFFE_ENFORCE(
g_output_.at(i).IsSparse(),
"Gradient of output ",
def_.output(i),
(g_output_.at(i).IsDense() ? " is dense (expected sparse)."
: " is not provided!"));
return g_output_.at(i).indices_;
}
string GO_V(const int i) {
CAFFE_ENFORCE(
g_output_.at(i).IsSparse(),
"Gradient of output ",
def_.output(i),
(g_output_.at(i).IsDense() ? " is dense (expected sparse)."
: " is not provided!"));
return g_output_.at(i).values_;
}
const GradientWrapper& GradOut(int i) {
return g_output_.at(i);
}
// Function to add a gradient pair to map.
void SetDense(const int i, const string& name) {
CAFFE_ENFORCE(
!g_input_.at(i).IsSparse(),
"Input ",
def_.input(i),
" already set to sparse.");
g_input_.at(i).dense_ = name;
}
void SetSparse(const int i, const string& indices, const string& values) {
CAFFE_ENFORCE(
!g_input_.at(i).IsDense(),
"Input ",
def_.input(i),
" already set to dense.");
g_input_.at(i).indices_ = indices;
g_input_.at(i).values_ = values;
}
/**
* @brief a helper function to allow one to create one single operator
* def, which is usually the case for many simple operators.
*/
template <class... Args>
inline static vector<OperatorDef> SingleGradientDef(const Args&... args) {
return vector<OperatorDef>{CreateOperatorDef(args...)};
}
public:
/**
* Returns map that returns the parameters that the gradients are for.
*/
static CaffeMap<string, string> MatchGradsToParams(const OperatorDef& op) {
// NOTE: how to go beyond string-matching?
CaffeMap<string, string> m;
for (auto& out : op.output()) {
if (IsGradientBlob(out)) {
m[out] = out.substr(0, out.length() - 5);
}
}
return m;
}
private:
// Utility functions for gradient name computation. We don't expose them
// in order to discourage the use of such names explicitly.
static string GradientName(const string& name) {
return name + "_grad";
}
static bool IsGradientBlob(const string& name) {
return name.length() > 5 && name.find("_grad") == name.length() - 5;
}
static string GradientNameToParam(const string& name) {
CHECK(IsGradientBlob(name));
return name.substr(0, name.length() - 5);
}
static string GradientSliceIndices(const string& name) {
return name + "_grad_indices";
}
static string GradientSliceValues(const string& name) {
return name + "_grad_values";
}
protected:
// We make the member variables protected in case someone wants to write
// a fully custom Get() function.
const OperatorDef& def_;
const vector<GradientWrapper>& g_output_;
vector<GradientWrapper> g_input_;
};
/**
 * @brief Gradient maker for operators that need no gradient computation.
 *
 * GetGradientDefs() yields no ops and registers no gradient names, so every
 * entry of the returned g_input_ stays empty. Register operators with the
 * NO_GRADIENT macro. This is different from SHOULD_NOT_DO_GRADIENT: the
 * latter means the gradient computation must not flow through the operator
 * at all, and raises an error if a gradient is requested.
 */
class TORCH_API NoGradient : public GradientMakerBase {
 private:
  using GradientMakerBase::GradientMakerBase;

  vector<OperatorDef> GetGradientDefs() override {
    return {};
  }
};
/**
 * @brief Gradient maker for operators that are designed to have no gradient.
 *
 * Get() always throws (via CAFFE_THROW) with a message naming the operator
 * type, so requesting a gradient for such an operator fails with an
 * exception. Registered through SHOULD_NOT_DO_GRADIENT.
 */
struct ThrowInTheTowelIfGradientIsCalled : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  GradientOpsMeta Get() override {
    const auto& op_type = def_.type();
    CAFFE_THROW(
        "One should not call gradient for operator ",
        op_type,
        ".");
  }
};
/**
 * @brief Placeholder gradient maker for operators whose gradient exists in
 * principle but has not been written yet.
 *
 * Get() always throws (via CAFFE_THROW) naming the operator type. Use this
 * sparingly, as a temporary stand-in; a real gradient operator should
 * eventually replace it. Registered through GRADIENT_NOT_IMPLEMENTED_YET.
 */
struct GradientNotImplementedYet : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  GradientOpsMeta Get() override {
    const auto& op_type = def_.type();
    CAFFE_THROW(
        "Operator ", op_type, " should have a gradient but is not implemented yet.");
  }
};
// Registry of gradient makers (subclasses of GradientMakerBase). Each maker
// is constructed from the forward OperatorDef and the gradients of its
// outputs, matching the two constructor argument types listed below.
// Entries are added by REGISTER_GRADIENT / REGISTER_GRADIENT_STR and keyed
// by the name passed to them (presumably the operator type; confirm against
// the registering call sites).
C10_DECLARE_REGISTRY(
GradientRegistry,
GradientMakerBase,
const OperatorDef&,
const vector<GradientWrapper>&);
// Gradient registration macros. When CAFFE2_NO_GRADIENT_OPS is defined, both
// macros expand to nothing and no gradient makers are registered. Otherwise
// REGISTER_GRADIENT registers the maker class in GradientRegistry via
// C10_REGISTER_CLASS, and REGISTER_GRADIENT_STR does so via
// C10_REGISTER_TYPED_CLASS (presumably for a key given as a string
// expression rather than an identifier; confirm in c10/util/Registry.h).
#ifdef CAFFE2_NO_GRADIENT_OPS
#define REGISTER_GRADIENT(name, ...) /* No gradients. */
#define REGISTER_GRADIENT_STR(str_name, ...) /* No gradients. */
#else
#define REGISTER_GRADIENT(name, ...) \
C10_REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
#define REGISTER_GRADIENT_STR(str_name, ...) \
C10_REGISTER_TYPED_CLASS(GradientRegistry, str_name, __VA_ARGS__)
#endif
// NO_GRADIENT means that the operator does not need any gradient computation.
// Its maker (NoGradient) produces no gradient ops.
#define NO_GRADIENT(name) REGISTER_GRADIENT(name, NoGradient)
// SHOULD_NOT_DO_GRADIENT means that the operator is not designed to have
// gradient operators. If you attempt to get its gradient, an exception is
// thrown (ThrowInTheTowelIfGradientIsCalled::Get uses CAFFE_THROW).
#define SHOULD_NOT_DO_GRADIENT(name) \
REGISTER_GRADIENT(name, ThrowInTheTowelIfGradientIsCalled)
// GRADIENT_NOT_IMPLEMENTED_YET marks an operator whose gradient should exist
// but has not been written; requesting it throws an exception
// (GradientNotImplementedYet::Get uses CAFFE_THROW).
#define GRADIENT_NOT_IMPLEMENTED_YET(name) \
REGISTER_GRADIENT(name, GradientNotImplementedYet)
/**
 * @brief Gets the GradientOpsMeta for the given operator def.
 *
 * Declared here and defined elsewhere. It presumably looks up the maker
 * registered in GradientRegistry for the operator and calls its Get();
 * confirm in the implementation file.
 *
 * @param def the forward operator whose gradient is requested.
 * @param g_output gradients available for each of def's outputs.
 * @return the gradient ops and the gradient name(s) for def's inputs.
 */
TORCH_API GradientOpsMeta GetGradientForOp(
const OperatorDef& def,
const vector<GradientWrapper>& g_output);
} // namespace caffe2
#endif // CAFFE2_CORE_OPERATOR_GRADIENT_H_