
Organize op lower #1532

Open · wants to merge 13 commits into develop
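This PR folds the separate OpLowerer::LowerWithoutSchedule entry point into OpLowerer::Lower, which now takes two schedule flags: every former LowerWithoutSchedule(group) call below becomes Lower(group, false, false), and the test helper forwards its apply_manual_schedule flag for both. A minimal sketch of the interface implied by those call sites follows; the parameter names apply_op_schedule/apply_group_schedule and the GroupPtr type are assumptions, since the header itself is not part of the shown diff.

// Sketch inferred from the call sites changed in this PR, not from the real header.
class OpLowerer {
 public:
  // Lower(group)               -> lower and schedule (old Lower(group))
  // Lower(group, false, false) -> lower only         (old LowerWithoutSchedule(group))
  std::vector<ir::LoweredFunc> Lower(const GroupPtr& group,
                                     bool apply_op_schedule = true,
                                     bool apply_group_schedule = true);
};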
@@ -156,7 +156,7 @@ TEST(AutoInline, AddReluInline) {
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target);

EXPECT_EQ(graph->fusion_groups.size(), 1UL);
-std::vector<ir::LoweredFunc> funcs = op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]);
+std::vector<ir::LoweredFunc> funcs = op_lowerer->Lower(graph->fusion_groups[0], false, false);

VLOG(6) << "Expr before auto inline: " << funcs[0]->body;

@@ -361,9 +361,9 @@ TEST_F(TestMultiLevelTiling, ReduceSum) {

TEST_F(TestMultiLevelTiling, Pool2d) {
default_input_names = {"input"};
-default_output_names = {"var_0"};
-std::vector<int32_t> input_shape{2, 8, 16, 16};
-std::vector<int32_t> output_shape{2, 8, 8, 8};
+default_output_names = {"var_0", "pad_temp_0"};
+std::vector<std::vector<int32_t>> input_shapes{{2, 8, 16, 16}};
+std::vector<std::vector<int32_t>> output_shapes{{2, 8, 8, 8}, {2, 8, 18, 18}};
std::string pooling_type = "max";
std::vector<int> ksize{3, 3};
std::vector<int> strides{2, 2};
@@ -374,7 +374,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
std::string data_format = "NCHW";
bool adaptive = false;
std::string padding_algorithm = "EXPLICIT";
-frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build({{"input", input_shape}},
+frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build({{"input", input_shapes[0]}},
{{"pool_type", pooling_type},
{"kernel_size", ksize},
{"stride_size", strides},
@@ -411,107 +411,104 @@ TEST_F(TestMultiLevelTiling, Pool2d) {
{
ScheduleBlock(root)
{
serial for (i, 0, 2)
{
serial for (j, 0, 8)
serial for (i, 0, 2)
{
serial for (k, 0, 18)
serial for (j, 0, 8)
{
serial for (a, 0, 18)
serial for (k, 0, 18)
{
ScheduleBlock(pad_temp_0)
serial for (a, 0, 18)
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
ScheduleBlock(pad_temp_0)
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
{
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
}
}
}
}
}
}
}
}
}
} // end Expr 0
Expr 1 {
{
ScheduleBlock(root_0)
{
{
thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16)
{
thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4)
thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16)
{
serial for (i_1, 0, 1)
thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4)
{
serial for (j_1, 0, 4)
serial for (i_1, 0, 1)
{
serial for (k_1, 0, 1)
serial for (j_1, 0, 4)
{
serial for (a_1, 0, 4)
serial for (k_1, 0, 1)
{
ScheduleBlock(var_0__reduce_init)
serial for (a_1, 0, 4)
{
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
ScheduleBlock(var_0__reduce_init)
{
var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
{
var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
}
}
}
}
}
}
}
{
serial for (kernel_idx, 0, 3)
{
serial for (kernel_idx_0, 0, 3)
serial for (kernel_idx, 0, 3)
{
serial for (ax0_ax1_ax2_ax3_fused, 0, 28)
serial for (kernel_idx_0, 0, 3)
{
ScheduleBlock(pad_temp_0_shared_temp_buffer)
serial for (ax0_ax1_ax2_ax3_fused, 0, 28)
{
v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0)))
attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0)
ScheduleBlock(pad_temp_0_shared_temp_buffer)
{
pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3]
v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0)))
attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0)
{
pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3]
}
}
}
}
serial for (i_1, 0, 1)
{
serial for (j_1, 0, 4)
serial for (i_1, 0, 1)
{
serial for (k_1, 0, 1)
serial for (j_1, 0, 4)
{
serial for (a_1, 0, 4)
serial for (k_1, 0, 1)
{
ScheduleBlock(var_0_local_temp_buffer)
serial for (a_1, 0, 4)
{
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
ScheduleBlock(var_0_local_temp_buffer)
{
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))])
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
{
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))])
}
}
}
}
}
}
}
}
}
serial for (ax0_0, 0, 1)
{
serial for (ax1_0, 0, 4)
serial for (ax0_0, 0, 1)
{
serial for (ax2_0, 0, 1)
serial for (ax1_0, 0, 4)
{
serial for (ax3_0, 0, 4)
serial for (ax2_0, 0, 1)
{
ScheduleBlock(var_0)
serial for (ax3_0, 0, 4)
{
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
ScheduleBlock(var_0)
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
}
}
}
}
@@ -524,7 +521,7 @@ Expr 1 {
}
}
}
-} // end Expr 1
+} // end Expr 0
)ROC";
ASSERT_EQ(ir, expected_ir);

@@ -539,8 +536,8 @@ Expr 1 {
BuildIRModule(MakeIRSchedule(pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names,
default_output_names,
-{input_shape},
-{output_shape},
+input_shapes,
+output_shapes,
target_);
}

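The updated expectations are internally consistent: pad_temp_0 becomes a second output with shape {2, 8, 18, 18}, matching the serial for (k, 0, 18) / serial for (a, 0, 18) extents in the expected IR, which implies a padding of 1 on each side of the 16x16 input; max-pooling that padded buffer with kernel 3 and stride 2 then gives the 8x8 extent of var_0. A stand-alone check of that arithmetic (the per-side padding of 1 is a deduction; the padding attribute itself sits outside the shown hunks):

#include <cstdio>

int main() {
  const int in_hw = 16, pad = 1, ksize = 3, stride = 2;    // pad = 1 is deduced, see above
  const int padded_hw = in_hw + 2 * pad;                   // 18 -> pad_temp_0 {2, 8, 18, 18}
  const int pooled_hw = (padded_hw - ksize) / stride + 1;  // 8  -> var_0      {2, 8, 8, 8}
  std::printf("padded=%dx%d, pooled=%dx%d\n", padded_hw, padded_hw, pooled_hw, pooled_hw);
  return 0;
}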
cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc (1 addition, 5 deletions)
@@ -58,11 +58,7 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(const frontend::Program& test
auto& shape_dict = graph->GetMutableAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_);

-if (apply_manual_schedule) {
-lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front());
-} else {
-lowered_funcs_ = op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front());
-}
+lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front(), apply_manual_schedule, apply_manual_schedule);
CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty";

std::vector<Expr> bodys;
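The single call above can replace the old branch because forwarding apply_manual_schedule for both flags reproduces both legacy paths. The assumed mapping, annotated (flag names as in the sketch at the top of the page):

// apply_manual_schedule == true  -> Lower(group, true, true)   == old Lower(group)
// apply_manual_schedule == false -> Lower(group, false, false) == old LowerWithoutSchedule(group)
lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front(),
                                  /*apply_op_schedule=*/apply_manual_schedule,
                                  /*apply_group_schedule=*/apply_manual_schedule);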
cinn/auto_schedule/task/tune_task.cc (1 addition, 1 deletion)
@@ -37,7 +37,7 @@ void TuneTask::Initialize(const absl::flat_hash_map<std::string, hlir::framework
op_lowerer = lower_handler;

// Set lowered_funcs and analyze output names.
-this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph);
+this->lowered_funcs = op_lowerer->Lower(subgraph, false, false);
this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs);
this->serialized_key = SerializeToString(shape_dict, dtype_dict);
}
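Lowering with both flags off is deliberate here: a TuneTask is the starting point of the schedule search, so it wants the bare lowered functions and leaves all scheduling to the tuner, which is the same reason the performance test below labels this path "No Schedule".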
cinn/auto_schedule/tests/performance_comparison_test.cc (1 addition, 1 deletion)
@@ -130,7 +130,7 @@ class PerformanceTester : public ::testing::Test {
compile_options.groups = graph->fusion_groups;

for (auto group : graph->fusion_groups) {
-compile_options.lowered_funcs.push_back(op_lowerer->LowerWithoutSchedule(group));
+compile_options.lowered_funcs.push_back(op_lowerer->Lower(group, false, false));
}

VLOG(3) << "===========================No Schedule LoweredFunc Begin===========================";