Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Fix ir copy of var #1470

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions cinn/optim/transform_gpu_forloop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,23 +170,31 @@ void CudaSyncThreadsDropIfThenElse(Expr *expr) {
Mutator()(expr);
}

// IR mutator that replaces every variable node it reaches with the result of
// IRCopy on that node. The call site describes the goal as copying var nodes
// "to prevent one modification leading to multiple changes".
class RestructureVarNodes : public ir::IRMutator<> {
 public:
  // Runs the mutation in place on `expr`, starting from the base-class
  // dispatch so that the private Visit override below is picked up.
  void operator()(ir::Expr *expr) {
    ir::IRMutator<>::Visit(expr, expr);
  }

 private:
  // Handler for _Var_ nodes: overwrite the slot `op` with IRCopy of its
  // current contents. `var` is not consulted.
  void Visit(const ir::_Var_ *var, Expr *op) override {
    auto fresh_node = IRCopy(*op);
    *op             = fresh_node;
  }
};

// IR mutator that, for each ScheduleBlockRealize it visits, substitutes every
// iter_var of the contained ScheduleBlock with the iter_value at the same
// position inside the block body, and then keeps traversing that body.
class ReplaceIndexToBindExpr : public ir::IRMutator<> {
 public:
  // Runs the mutation in place on `expr` via the base-class dispatch.
  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }

 private:
  // Handler for ScheduleBlockRealize nodes. `op` is not consulted; the node is
  // re-obtained from `expr` as a mutable pointer.
  void Visit(const ir::ScheduleBlockRealize *op, Expr *expr) override {
    ir::ScheduleBlockRealize *schedule_block_realize = expr->As<ir::ScheduleBlockRealize>();
    // The realize node must wrap an actual ScheduleBlock.
    CHECK(schedule_block_realize->schedule_block.As<ir::ScheduleBlock>());
    std::vector<ir::Expr> iter_values = schedule_block_realize->iter_values;
    // NOTE(review): presumably ir::Expr is a shared handle, so substitutions
    // made through &body are visible in the schedule block itself -- confirm.
    ir::Expr body                  = schedule_block_realize->schedule_block.As<ir::ScheduleBlock>()->body;
    std::vector<ir::Var> iter_vars = schedule_block_realize->schedule_block.As<ir::ScheduleBlock>()->iter_vars;

    // Each iteration variable needs exactly one bound value.
    CHECK_EQ(iter_values.size(), iter_vars.size());
    for (size_t idx = 0; idx < iter_values.size(); ++idx) {
      ReplaceVarWithExpr(&body, iter_vars[idx], iter_values[idx]);
    }
    // Continue into the (now substituted) body so nested realizes are handled.
    ir::IRMutator<>::Visit(&body, &body);
  }
};

Expand Down Expand Up @@ -594,6 +602,11 @@ class ReplaceVarToZero : public ir::IRMutator<> {

void OptimizeExprGPU(Expr *expr) {
VLOG(2) << "Before Optimize Expr:\n" << *expr;

// copy var nodes to prevent one modification leading to multiple changes
RestructureVarNodes restructure_var_nodes;
restructure_var_nodes(expr);

// replace var to bind expr
ReplaceIndexToBindExpr replace_index_to_bind_expr;
replace_index_to_bind_expr(expr);
Expand Down