Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
rewrite unittest (#889)
Browse files Browse the repository at this point in the history
  • Loading branch information
haozech authored Aug 19, 2022
1 parent 7febbab commit 0431271
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 106 deletions.
125 changes: 67 additions & 58 deletions cinn/hlir/framework/graph_compiler.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -310,74 +310,24 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFuncWithIRSchedule(
input_output_nodes.push_back(id);
}

for (auto& i : GetAllNodeData(node)) {
VLOG(3) << "cinn_inputs.push_back " << i->id();
cinn_inputs.push_back(common::CINNValue(i->id()));
}

std::vector<Type> out_types;
std::vector<std::vector<int>> out_shapes;
auto node_datas = GetAllNodeData(node);
for (auto node_data : node_datas) {
// collect output node data name.
out_types.push_back(type_dict_.at(node_data->id()));
out_shapes.push_back(shape_dict_.at(node_data->id()));
input_output_nodes.push_back(node_data->id());
std::string out_name = node_data->id();
VLOG(3) << "cinn_inputs.push_back " << out_name;
cinn_inputs.push_back(common::CINNValue(out_name));
out_types.push_back(type_dict_.at(out_name));
out_shapes.push_back(shape_dict_.at(out_name));
input_output_nodes.push_back(out_name);
}

// 2.Call Op's Compute function, using the default stages and LowerVec to get IR tree.
auto impl =
OpStrategy::SelectImpl(cinn_strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, target_));
common::CINNValuePack C = impl->fcompute(common::CINNValuePack{cinn_inputs});
auto all_arg_tensors = tensor_inputs;

// 3. Collect tensors and arguments
// Add output tensors to all_arg_tensors
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
// checkout whether the tensor is with buffer.
if (!temp.as_tensor_ref()->buffer.defined() || target_ != common::DefaultNVGPUTarget()) {
all_arg_tensors.push_back(temp.as_tensor_ref());
}
}

poly::StageMap stages = C.back();
std::string func_name_prefix = "fn_";
auto func = lang::LowerVec(func_name_prefix + node->id(), stages, all_arg_tensors, {}, {}, nullptr, target_, true);

std::vector<common::CINNValue> schedule_inputs;
for (auto& f : func) {
schedule_inputs.push_back(common::CINNValue(f->body));
}
for (int i = 0; i < C->size() - 1; i++) {
ir::Expr temp = C[i];
schedule_inputs.push_back(common::CINNValue(temp.as_tensor_ref()->name));
}

// 4. Call Op's Schedule function, optimizing the IR tree by new IR schedule
common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs});

// 5. Optimize the LoweredFunc
VLOG(3) << "expr_pack.size() is : " << expr_pack.size();
std::vector<ir::LoweredFunc> res;
for (int i = 0; i < expr_pack.size(); i++) {
if (func.size() > expr_pack.size()) {
auto new_args = lang::GetArgs(func[i]->body, input_output_nodes);
func[i]->args = new_args;
}
auto temp_buffers = lang::GetTempBuffers(all_arg_tensors, stages, func[i]->body);
func[i]->temp_bufs = temp_buffers;
func[i]->PrepareBufferCastExprs();
res.push_back(func[i]);
}
for (int i = 0; i < res.size(); i++) {
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(res[i]->body));
#endif
res[i] = optim::Optimize(Expr(res[i]), target_, false).as_lowered_func_ref();
}

// 6. Return the result.
auto res =
GetFuncFromImpl(impl, common::CINNValuePack{cinn_inputs}, tensor_inputs, input_output_nodes, node->id(), target_);
return res;
}

Expand Down Expand Up @@ -1547,6 +1497,65 @@ std::vector<ir::LoweredFunc> GraphCompiler::FusedNodeGroupToLoweredFunc(
return funcs;
}

// Lowers the op behind `impl` to a vector of LoweredFuncs using the new IR
// schedule.
//
// @param impl               The selected op implementation (compute + schedule).
// @param cinn_inputs        Packed compute inputs (input/output tensor names).
// @param all_arg_tensors    In/out: the op's input tensors; output tensors are
//                           appended in place so they become lowering arguments.
// @param input_output_nodes Names of all input and output node data, used to
//                           recompute function arguments when needed.
// @param node_id            Node id used to build the function name ("fn_" + id).
// @param target             Compilation target (CPU/GPU).
// @return The optimized LoweredFuncs for this op.
std::vector<ir::LoweredFunc> GetFuncFromImpl(const std::shared_ptr<OpImpl>& impl,
                                             const common::CINNValuePack& cinn_inputs,
                                             std::vector<ir::Tensor>& all_arg_tensors,
                                             const std::vector<std::string>& input_output_nodes,
                                             const std::string& node_id,
                                             const Target& target) {
  // 1. Call the op's Compute function. The resulting pack holds the output
  //    tensors followed by a trailing StageMap slot.
  common::CINNValuePack C = impl->fcompute(cinn_inputs);

  // Number of real outputs; computed once as a signed value so the loops below
  // are safe even for an empty pack (unsigned size() - 1 would wrap).
  const int output_num = static_cast<int>(C->size()) - 1;

  // 2. Collect tensors and arguments.
  //    Append output tensors to all_arg_tensors; on NVGPU, tensors that already
  //    have a buffer bound are skipped.
  for (int i = 0; i < output_num; i++) {
    ir::Expr temp = C[i];
    // Check whether the tensor already has a buffer defined.
    if (!temp.as_tensor_ref()->buffer.defined() || target != common::DefaultNVGPUTarget()) {
      all_arg_tensors.push_back(temp.as_tensor_ref());
    }
  }

  poly::StageMap stages = C.back();
  std::string func_name_prefix = "fn_";
  auto func = lang::LowerVec(func_name_prefix + node_id, stages, all_arg_tensors, {}, {}, nullptr, target, true);

  // Build the schedule inputs: first every lowered function body, then the
  // name of every output tensor.
  std::vector<common::CINNValue> schedule_inputs;
  schedule_inputs.reserve(func.size() + output_num);
  for (auto& f : func) {
    schedule_inputs.push_back(common::CINNValue(f->body));
  }
  for (int i = 0; i < output_num; i++) {
    ir::Expr temp = C[i];
    schedule_inputs.push_back(common::CINNValue(temp.as_tensor_ref()->name));
  }

  // 3. Call the op's Schedule function, optimizing the IR tree by the new IR
  //    schedule.
  common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs});

  // 4. Optimize the LoweredFuncs.
  VLOG(3) << "expr_pack.size() is : " << expr_pack.size();
  // Loop-invariant: arguments only need recomputing when lowering produced
  // more funcs than the schedule returned (hoisted out of the loop below).
  const bool recompute_args = func.size() > expr_pack.size();
  std::vector<ir::LoweredFunc> res;
  res.reserve(expr_pack.size());
  for (int i = 0; i < expr_pack.size(); i++) {
    if (recompute_args) {
      func[i]->args = lang::GetArgs(func[i]->body, input_output_nodes);
    }
    func[i]->temp_bufs = lang::GetTempBuffers(all_arg_tensors, stages, func[i]->body);
    func[i]->PrepareBufferCastExprs();
    res.push_back(func[i]);
  }
  for (auto& f : res) {
#ifdef CINN_WITH_CUDA
    optim::OptimizeExprGPU(&(f->body));
#endif
    f = optim::Optimize(Expr(f), target, false).as_lowered_func_ref();
  }

  // 5. Return the result.
  return res;
}

} // namespace framework
} // namespace hlir
} // namespace cinn
8 changes: 8 additions & 0 deletions cinn/hlir/framework/graph_compiler.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ std::shared_ptr<Scope> BuildScope(Target target,
const std::shared_ptr<Graph>& graph,
std::shared_ptr<Scope> scope = nullptr);

// Given params, lower the op to LoweredFunc using the new IR Schedule.
// Note: `tensor_inputs` is an in/out parameter — the op's output tensors are
// appended to it in place during lowering.
std::vector<ir::LoweredFunc> GetFuncFromImpl(const std::shared_ptr<OpImpl>& impl,
                                             const common::CINNValuePack& cinn_inputs,
                                             std::vector<ir::Tensor>& tensor_inputs,
                                             const std::vector<std::string>& input_output_nodes,
                                             const std::string& node_id,
                                             const Target& target);

} // namespace framework
} // namespace hlir
} // namespace cinn
2 changes: 1 addition & 1 deletion cinn/hlir/op/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ std::shared_ptr<OpStrategy> StrategyForBroadcastTo(const framework::NodeAttr &at
if (target.arch == Target::Arch::NVGPU) {
pe::IRCudaScheduleInjective(ir_sch, out_shape, target);
} else if (target.arch == Target::Arch::X86) {
pe::IRScheduleInjectiveCPU(ir_sch, out_shape, target);
pe::IRScheduleInjectiveCPU(ir_sch, out_shape, target, false);
}
std::vector<CINNValue> res{CINNValue(ir_sch.GetModule().GetExprs().at(0))};
*ret = CINNValuePack{res};
Expand Down
Loading

0 comments on commit 0431271

Please sign in to comment.