Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Add computeinline and fix bug (#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
haozech authored Sep 28, 2020
1 parent fb0683c commit 5194cf3
Show file tree
Hide file tree
Showing 12 changed files with 157 additions and 234 deletions.
6 changes: 3 additions & 3 deletions cinn/frontend/syntax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Variable Program::conv2d(const Variable& a,
instr.SetAttr(iter.first, iter.second);
}
AppendInstruction(instr);
return instr.GetOutput(2);
return instr.GetOutput(0);
}

Variable Program::depthwise_conv2d(const Variable& a,
Expand All @@ -57,7 +57,7 @@ Variable Program::depthwise_conv2d(const Variable& a,
instr.SetAttr(iter.first, iter.second);
}
AppendInstruction(instr);
return instr.GetOutput(1);
return instr.GetOutput(0);
}

Variable Program::pool2d(const Variable& a, const std::unordered_map<std::string, attr_t>& attr_store) {
Expand All @@ -67,7 +67,7 @@ Variable Program::pool2d(const Variable& a, const std::unordered_map<std::string
instr.SetAttr(iter.first, iter.second);
}
AppendInstruction(instr);
return instr.GetOutput(1);
return instr.GetOutput(0);
}

Variable Program::batchnorm(const Variable& a,
Expand Down
1 change: 1 addition & 0 deletions cinn/hlir/framework/graph_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ ir::LoweredFunc GraphCompiler::GetOpFunc(const Node* node) {
std::vector<ir::Tensor> inputs;
std::vector<common::CINNValue> cinn_inputs;
std::vector<std::vector<int>> output_shapes;
LOG(INFO) << "GetOpFunc of op " << node->id();
for (auto& i : node->inlinks_in_order()) {
std::string input_id = i->source()->as<NodeData>()->id();
auto in_shape = shape_dict.at(input_id);
Expand Down
1 change: 0 additions & 1 deletion cinn/hlir/op/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@ foreach(cpp ${srcs})
endforeach()

cc_test(test_op_broadcast SRCS op_broadcast_test.cc DEPS core)
cc_test(test_op_nn SRCS op_nn_test.cc DEPS core)
164 changes: 74 additions & 90 deletions cinn/hlir/op/nn.cc

Large diffs are not rendered by default.

130 changes: 46 additions & 84 deletions cinn/hlir/pe/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,37 +45,24 @@ std::vector<ir::Tensor> Conv2d_NCHW(const ir::Tensor &input,
int stride_w,
int dilation_h,
int dilation_w,
const std::vector<std::vector<int>> &output_shapes,
const std::string &output_name) {
CHECK_EQ(input->shape.size(), 4U) << "Input's dimension of Conv2d_NCHW op is not 4! Please check.";
CHECK_EQ(weights->shape.size(), 4U) << "Weight's dimension of Conv2d_NCHW op is not 4! Please check.";
std::vector<Expr> output_shape;
std::vector<Expr> new_weights_shape;
std::vector<Expr> input_pad_shape;
if (output_shapes.size() == 3) {
// already computed by infer_shape
CHECK_EQ(output_shapes[0].size(), 4U) << "The size of output_shapes[0] of Conv2d op is not 4! Please check.";
CHECK_EQ(output_shapes[1].size(), 4U) << "The size of output_shapes[1] of Conv2d op is not 4! Please check.";
CHECK_EQ(output_shapes[2].size(), 4U) << "The size of output_shapes[2] of Conv2d op is not 4! Please check.";
output_shape = {
Expr(output_shapes[2][0]), Expr(output_shapes[2][1]), Expr(output_shapes[2][2]), Expr(output_shapes[2][3])};
new_weights_shape = {
Expr(output_shapes[1][0]), Expr(output_shapes[1][1]), Expr(output_shapes[1][2]), Expr(output_shapes[1][3])};
input_pad_shape = {
Expr(output_shapes[0][0]), Expr(output_shapes[0][1]), Expr(output_shapes[0][2]), Expr(output_shapes[0][3])};
} else {
output_shape = {
input->shape[0], // B
weights->shape[0], // O
Expr((input->shape[2] - ((weights->shape[2] - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1), // H
Expr((input->shape[3] - ((weights->shape[3] - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1) // W
};
new_weights_shape = {weights->shape[0],
weights->shape[1],
dilation_h * (weights->shape[2] - 1) + 1,
dilation_w * (weights->shape[3] - 1) + 1};
input_pad_shape = {input->shape[0], input->shape[1], input->shape[2] + 2 * pad_h, input->shape[3] + 2 * pad_w};
}
output_shape = {
input->shape[0], // B
weights->shape[0], // O
Expr((input->shape[2] - ((weights->shape[2] - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1), // H
Expr((input->shape[3] - ((weights->shape[3] - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1) // W
};
new_weights_shape = {weights->shape[0],
weights->shape[1],
dilation_h * (weights->shape[2] - 1) + 1,
dilation_w * (weights->shape[3] - 1) + 1};
input_pad_shape = {input->shape[0], input->shape[1], input->shape[2] + 2 * pad_h, input->shape[3] + 2 * pad_w};

auto input_pad = Compute(
input_pad_shape,
[=](Expr nn, Expr cc, Expr yy, Expr xx) {
Expand Down Expand Up @@ -123,38 +110,25 @@ std::vector<ir::Tensor> Conv2d_NHWC(const ir::Tensor &input,
int stride_w,
int dilation_h,
int dilation_w,
const std::vector<std::vector<int>> &output_shapes,
const std::string &output_name) {
CHECK_EQ(input->shape.size(), 4U) << "Input's dimension of Conv2d_NHWC op is not 4! Please check.";
CHECK_EQ(weights->shape.size(), 4U) << "Weight's dimension of Conv2d_NHWC op is not 4! Please check.";
std::vector<Expr> output_shape;
std::vector<Expr> new_weights_shape;
std::vector<Expr> input_pad_shape;
if (output_shapes.size() == 3) {
// already computed by infer_shape
CHECK_EQ(output_shapes[0].size(), 4U) << "The size of output_shapes[0] of Conv2d op is not 4! Please check.";
CHECK_EQ(output_shapes[1].size(), 4U) << "The size of output_shapes[1] of Conv2d op is not 4! Please check.";
CHECK_EQ(output_shapes[2].size(), 4U) << "The size of output_shapes[2] of Conv2d op is not 4! Please check.";
output_shape = {
Expr(output_shapes[2][0]), Expr(output_shapes[2][1]), Expr(output_shapes[2][2]), Expr(output_shapes[2][3])};
new_weights_shape = {
Expr(output_shapes[1][0]), Expr(output_shapes[1][1]), Expr(output_shapes[1][2]), Expr(output_shapes[1][3])};
input_pad_shape = {
Expr(output_shapes[0][0]), Expr(output_shapes[0][1]), Expr(output_shapes[0][2]), Expr(output_shapes[0][3])};
} else {
output_shape = {
input->shape[0], // B
Expr((input->shape[1] - ((weights->shape[2] - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1), // H
Expr((input->shape[2] - ((weights->shape[3] - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1), // W
weights->shape[0] // O
};
new_weights_shape = {weights->shape[0],
weights->shape[1],
dilation_h * (weights->shape[2] - 1) + 1,
dilation_w * (weights->shape[3] - 1) + 1};
input_pad_shape = {input->shape[0], input->shape[1] + 2 * pad_h, input->shape[2] + 2 * pad_w, input->shape[3]};
}
auto input_pad = Compute(

output_shape = {
input->shape[0], // B
Expr((input->shape[1] - ((weights->shape[2] - 1) * dilation_h + 1) + 2 * pad_h) / stride_h + 1), // H
Expr((input->shape[2] - ((weights->shape[3] - 1) * dilation_w + 1) + 2 * pad_w) / stride_w + 1), // W
weights->shape[0] // O
};
new_weights_shape = {weights->shape[0],
weights->shape[1],
dilation_h * (weights->shape[2] - 1) + 1,
dilation_w * (weights->shape[3] - 1) + 1};
input_pad_shape = {input->shape[0], input->shape[1] + 2 * pad_h, input->shape[2] + 2 * pad_w, input->shape[3]};
auto input_pad = Compute(
input_pad_shape,
[=](Expr nn, Expr yy, Expr xx, Expr cc) {
auto cond =
Expand Down Expand Up @@ -200,28 +174,20 @@ std::vector<Tensor> Depthwise_Conv2d_NCHW(const Tensor &input,
int pad_w,
int stride_h,
int stride_w,
const std::vector<std::vector<int>> &output_shapes,
const std::string output_name) {
CHECK_EQ(input->shape.size(), 4U) << "Input's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n";
CHECK_EQ(weight->shape.size(), 4U) << "Weight's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n";
Expr in_h = input->shape[2];
Expr in_w = input->shape[3];
Expr c_m = weight->shape[1]; // channel_multiplier
std::vector<Expr> output_shape;
if (output_shapes.size() == 2) {
// already computed by infer_shape
CHECK_EQ(output_shapes[1].size(), 4U)
<< "The size of output_shapes[1] of Depthwise_Conv2d op is not 4! Please check.";
output_shape = {
Expr(output_shapes[1][0]), Expr(output_shapes[1][1]), Expr(output_shapes[1][2]), Expr(output_shapes[1][3])};
} else {
output_shape = {
input->shape[0], // B
weight->shape[1] * input->shape[1], // O
(input->shape[2] - weight->shape[2] + 2 * pad_h) / stride_h + 1, // H
(input->shape[3] - weight->shape[3] + 2 * pad_w) / stride_w + 1 // W
};
}

output_shape = {
input->shape[0], // B
weight->shape[1] * input->shape[1], // O
(input->shape[2] - weight->shape[2] + 2 * pad_h) / stride_h + 1, // H
(input->shape[3] - weight->shape[3] + 2 * pad_w) / stride_w + 1 // W
};
auto input_pad =
(pad_h == 0 && pad_w == 0) ? Identity(input) : Pad(input, {Expr(0), Expr(0), Expr(pad_h), Expr(pad_w)});

Expand All @@ -245,28 +211,20 @@ std::vector<Tensor> Depthwise_Conv2d_NHWC(const Tensor &input,
int pad_w,
int stride_h,
int stride_w,
const std::vector<std::vector<int>> &output_shapes,
const std::string output_name) {
CHECK_EQ(input->shape.size(), 4U) << "Input's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n";
CHECK_EQ(weight->shape.size(), 4U) << "Weight's dimension of Depthwise_Conv2d_NCHW is not 4! Please check.\n";
Expr in_h = input->shape[1];
Expr in_w = input->shape[2];
Expr c_m = weight->shape[1]; // channel_multiplier
std::vector<Expr> output_shape;
if (output_shapes.size() == 2) {
// already computed by infer_shape
CHECK_EQ(output_shapes[1].size(), 4U)
<< "The size of output_shapes[1] of Depthwise_Conv2d op is not 4! Please check.";
output_shape = {
Expr(output_shapes[1][0]), Expr(output_shapes[1][1]), Expr(output_shapes[1][2]), Expr(output_shapes[1][3])};
} else {
output_shape = {
input->shape[0], // B
(input->shape[1] - weight->shape[2] + 2 * pad_h) / stride_h + 1, // H
(input->shape[2] - weight->shape[3] + 2 * pad_w) / stride_w + 1, // W
weight->shape[1] * input->shape[3] // O
};
}

output_shape = {
input->shape[0], // B
(input->shape[1] - weight->shape[2] + 2 * pad_h) / stride_h + 1, // H
(input->shape[2] - weight->shape[3] + 2 * pad_w) / stride_w + 1, // W
weight->shape[1] * input->shape[3] // O
};

auto input_pad =
(pad_h == 0 && pad_w == 0) ? Identity(input) : Pad(input, {Expr(0), Expr(pad_h), Expr(pad_w), Expr(0)});
Expand Down Expand Up @@ -541,7 +499,7 @@ std::vector<Tensor> PoolImpl(const Tensor &tensor,
if (pool_type == "max") {
Expr min_value = ir::min_value(tensor->type());
// Pad the input tensor with the pad_value of type's minimum value
temp = do_pad ? Pad(tensor, pad_before, pad_after, min_value, UniqName("pad_temp")) : Identity(tensor);
temp = do_pad ? Pad(tensor, pad_before, pad_after, min_value, UniqName("pad_temp")) : tensor;
res = Compute(
out_shape,
[=](const std::vector<Expr> &output) {
Expand All @@ -559,7 +517,7 @@ std::vector<Tensor> PoolImpl(const Tensor &tensor,
daxis);
} else if (pool_type == "avg") {
// Pad the input tensor with pad_value zero
temp = do_pad ? Pad(tensor, pad_before, pad_after, 0, UniqName("pad_temp")) : Identity(tensor);
temp = do_pad ? Pad(tensor, pad_before, pad_after, 0, UniqName("pad_temp")) : tensor;
res = Compute(
out_shape,
[=](const std::vector<Expr> &output) {
Expand Down Expand Up @@ -599,7 +557,11 @@ std::vector<Tensor> PoolImpl(const Tensor &tensor,
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
}
return {temp, res};
if (do_pad) {
return {temp, res};
} else {
return {res};
}
}

std::vector<Tensor> Pool1d(const Tensor &tensor,
Expand Down
4 changes: 0 additions & 4 deletions cinn/hlir/pe/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ std::vector<ir::Tensor> Conv2d_NCHW(const ir::Tensor &input,
int stride_w,
int dilation_h,
int dilation_w,
const std::vector<std::vector<int>> &output_shapes,
const std::string &output_name = UniqName("T_Conv2d_NCHW_out"));

/**
Expand All @@ -124,7 +123,6 @@ std::vector<ir::Tensor> Conv2d_NHWC(const ir::Tensor &input,
int stride_w,
int dilation_h,
int dilation_w,
const std::vector<std::vector<int>> &output_shapes,
const std::string &output_name = UniqName("T_Conv2d_NHWC_out"));

/**
Expand All @@ -147,7 +145,6 @@ std::vector<ir::Tensor> Depthwise_Conv2d_NCHW(const ir::Tensor &input,
int pad_w,
int stride_h,
int stride_w,
const std::vector<std::vector<int>> &output_shapes,
const std::string output_name = UniqName("T_depthwise_conv2d_nchw"));

/**
Expand All @@ -170,7 +167,6 @@ std::vector<ir::Tensor> Depthwise_Conv2d_NHWC(const ir::Tensor &input,
int pad_w,
int stride_h,
int stride_w,
const std::vector<std::vector<int>> &output_shapes,
const std::string output_name = UniqName("T_depthwise_conv2d_nhwc"));

ir::Tensor BatchNorm_NCHW(const ir::Tensor &input,
Expand Down
10 changes: 10 additions & 0 deletions cinn/lang/lower_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,14 @@ ir::LoweredFunc LowerImpl::operator()() {
if (arg->is_placeholder_node()) continue;
if (arg->buffer.defined()) continue;
if (arg->body().As<ir::Call>() && arg->body().type().is_void()) continue; // extern call
if (tensor_map.find(arg->name) == tensor_map.end()) {
LOG(INFO) << "Didn't find arg tensor " << arg->name << "in tensor_map.\n"
<< "The function is " << fn_name_ << "\nAnd all the arg tensors are:\n";
for (auto& i : tensor_args_) {
LOG(INFO) << i->name;
}
LOG(FATAL) << "Fatal Error!";
}
Reference(&arg)->buffer = tensor_map.at(arg->name)->buffer;
}
}
Expand All @@ -421,7 +429,9 @@ ir::LoweredFunc LowerImpl::operator()() {
auto func = ir::_LoweredFunc_::Make(fn_name_, func_args, func_body, temp_buffers);

// some necessary modification.
LOG(INFO) << "Before optim::ComputeInlineExpand(&func->body, stages_); in function " << fn_name_;
optim::ComputeInlineExpand(&func->body, stages_);
LOG(INFO) << "After optim::ComputeInlineExpand(&func->body, stages_); in function " << fn_name_;
Target target = cuda_axis_info_.valid() ? common::DefaultNVGPUTarget() : common::DefaultHostTarget();
auto res = optim::Optimize(func, target, FLAGS_cinn_runtime_display_debug_info);

Expand Down
6 changes: 5 additions & 1 deletion cinn/optim/compute_inline_expand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> {

TensorInlineExpandMutator(const std::string &tensor_name) : tensor_name(tensor_name) {}

void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
void operator()(Expr *expr) {
LOG(INFO) << "void operator()(Expr *expr) Begin";
ir::IRMutator<>::Visit(expr, expr);
LOG(INFO) << "void operator()(Expr *expr) End";
}

void Visit(const ir::Load *op, Expr *expr) override {
auto *node = expr->As<ir::Load>();
Expand Down
17 changes: 11 additions & 6 deletions cinn/pybind/framework.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,24 @@ void BindFramework(pybind11::module *m) {
auto impl = OpStrategy::SelectImpl(self[op_ptr](attrs, inputs, out_types, output_shapes, target));
std::vector<common::CINNValue> temp_inputs;
std::vector<ir::Tensor> res;
for (auto tensor : inputs) {
for (auto &tensor : inputs) {
res.push_back(tensor);
temp_inputs.push_back(common::CINNValue(tensor));
}
auto stages = CreateStages(inputs);
temp_inputs.push_back(common::CINNValue(stages));
common::CINNValuePack C = impl->fcompute(common::CINNValuePack{temp_inputs});
C = impl->fschedule(C);
for (int i = 0; i < C.get()->size() - 1; i++) {
poly::StageMap stages = C.back();
// make sure all the tensors in the stages before schedule launch.
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
stages->InsertLazily(temp.as_tensor_ref());
}
C = impl->fschedule(C);
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
res.push_back(temp.as_tensor_ref());
}
return res;
auto func = Lower(key, stages, res);
return func;
});

py::class_<NodeAttr>(*m, "NodeAttr")
Expand Down
22 changes: 2 additions & 20 deletions python/tests/conv2d_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,26 +72,8 @@ def conv2d_native(inputs_data, input_shape, filter_size, attrs, is_depthwise):
print("output's shape is:", output.shape)

res_shape = output.shape[1:]
pad_shape = list(input_shape)
dilation_shape = list(filter_size_new)
assert len(padding) == 2
assert len(pad_shape) == 4
assert len(dilation_shape) == 4
if data_format == "NCHW":
h_index = 2
w_index = 3
else:
h_index = 1
w_index = 2

pad_shape[h_index] += 2 * padding[0]
pad_shape[w_index] += 2 * padding[1]
dilation_shape[2] = (filter_size_new[2] - 1) * dilation[0] + 1
dilation_shape[3] = (filter_size_new[3] - 1) * dilation[1] + 1

print("pad's shape is:", pad_shape)
print("dilation's shape is:", dilation_shape)
if is_depthwise:
return output, [pad_shape, res_shape]
return output, [res_shape]
else:
return output, [pad_shape, dilation_shape, res_shape]
return output, [res_shape]
Loading

0 comments on commit 5194cf3

Please sign in to comment.