Math API: support int tensor autocast to float32 (usability improvement) #69252

Open
wants to merge 2 commits into base: develop
89 changes: 89 additions & 0 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -131,6 +131,48 @@
"remainder_": ["x", "y"],
}

# ops support casting int tensor into float32 to do forward calculation
type_autocast_op_list = {
"acos": ["x"],
"acosh": ["x"],
"asin": ["x"],
"asinh": ["x"],
"atan": ["x"],
"atanh": ["x"],
"ceil": ["x"],
"cos": ["x"],
"cosh": ["x"],
"digamma": ["x"],
"erf": ["x"],
"erfinv": ["x"],
"floor": ["x"],
"i0": ["x"],
"i0e": ["x"],
"i1": ["x"],
"i1e": ["x"],
"lgamma": ["x"],
"logcumsumexp": ["x"],
"logit": ["x"],
"logsumexp": ["x"],
"polygamma": ["x"],
"reciprocal": ["x"],
"rsqrt": ["x"],
"sigmoid": ["x"],
"sin": ["x"],
"sinh": ["x"],
"sqrt": ["x"],
"stanh": ["x"],
"tan": ["x"],
"tanh": ["x"],
}

# ops support casting int tensor into float32 to do forward calculation,
# and it is valid to cast float32 gradient back to int tensor.
type_autocast_valid_grad_op_list = {
"ceil",
"floor",
}
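
(Not part of the PR's diff.) A minimal, hedged sketch of the user-visible behavior these two lists are intended to enable, assuming a Paddle build that includes this change; exact warnings and defaults may differ:

import paddle

x = paddle.to_tensor([1, 2, 3], dtype="int64")

y = paddle.sin(x)   # previously rejected for int input; now autocast to float32 first
print(y.dtype)      # expected: paddle.float32

z = paddle.ceil(x)  # "ceil" is also in type_autocast_valid_grad_op_list, so casting
print(z.dtype)      # the float32 gradient back toward the int input is considered valid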

# dict of special api that forward api's output will affect backward api's output
# backward api's output usually affected by backward api's input

@@ -327,6 +369,8 @@ class {} : public egr::GradNodeBase {{
// AMP Logic
{}
// Type promotion Logic
{}
// Type autocast Logic
{}
// Layout autotune
{}
@@ -404,6 +448,8 @@ class {} : public egr::GradNodeBase {{
// AMP Logic
{}
// Type promotion Logic
{}
// Type autocast Logic
{}
// Layout autotune
{}
@@ -618,6 +664,15 @@ class {} : public egr::GradNodeBase {{
}}
"""

Review comment (Contributor), attached to the template below:
Is the cast applied only when the input is an int tensor, or in all cases? This must not introduce an incompatible performance regression; only the cases that previously were unsupported and raised an error should be cast.

Reply from @NKNaN (Contributor, Author), Nov 13, 2024:
The cast only happens when the op is in the specified op list and its input is an int tensor, so it should not break compatibility. The CE-Framework failures may require changing some unit tests in PaddleTest, because a few of them contain cases where an int input is expected to raise an error.

Results from a run on AIStudio: [screenshot omitted]

The paddle.sin unit test is commented out in PaddleTest: PaddlePaddle/PaddleTest#2992

TYPE_AUTOCAST_LOGIC_TEMPLATE = """
if (phi::NeedTypeAutoCast({op_func_name}, {x}.dtype())) {{
VLOG(5) << "math operation got integer input data type, run type autocast.";
LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed.";
{op_name}
auto new_{x} = egr::PromoteCast("{x}", {x}, phi::DataType::FLOAT32, {trace_backward});
{return_value}
}}
"""

LAYOUT_LOGIC_TEMPLATE = """
if (egr::Controller::Instance().UseLayoutAutoTune()) {{
@@ -1563,6 +1618,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):

amp_inputs_call_list = ["" for i in range(num_inputs)]
type_promote_inputs_call_list = ["" for i in range(num_inputs)]
type_autocast_inputs_call_list = ["" for i in range(num_inputs)]
amp_tensors_vector_list = []
amp_tensors_vector_optional_list = []
amp_autocast_list = []
@@ -1591,6 +1647,11 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
type_promote_inputs_call_list[pos] = f"new_{name}"
else:
type_promote_inputs_call_list[pos] = f"{name}"
if forward_api_name in type_autocast_op_list:
if name in type_autocast_op_list[forward_api_name]:
type_autocast_inputs_call_list[pos] = f"new_{name}"
else:
type_autocast_inputs_call_list[pos] = f"{name}"
if IsPlainTensorType(ttype):
if is_optional:
if (
@@ -1682,6 +1743,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
inputs_call_list[pos] = name
amp_inputs_call_list[pos] = name
type_promote_inputs_call_list[pos] = name
type_autocast_inputs_call_list[pos] = name
if default_val is not None:
inputs_args_declaration_list[pos] = (
f"{atype} {name} = {default_val}"
@@ -1971,6 +2033,31 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
)
else:
type_promotion_logic_str = f'\n VLOG(5) << " No Type Promotion for {forward_ad_function_name} api. "; '

# Forward type autocast logic
if forward_api_name in type_autocast_op_list:
# only support one inputs
op_func_name = f'"{forward_api_name}"'
x = type_autocast_op_list[forward_api_name][0]
type_autocast_inputs_call_args_str = ", ".join(
type_autocast_inputs_call_list
)
trace_backward = (
forward_api_name in type_autocast_valid_grad_op_list
) and (not self.is_forward_only)
trace_backward = str(trace_backward).lower()
type_autocast_call_list = f"return {forward_ad_function_name}({type_autocast_inputs_call_args_str});"

type_autocast_logic_str = TYPE_AUTOCAST_LOGIC_TEMPLATE.format(
op_func_name=op_func_name,
x=x,
op_name=kernel_trans2_op_name_str,
trace_backward=trace_backward,
return_value=type_autocast_call_list,
)
else:
type_autocast_logic_str = f'\n VLOG(5) << " No Type Autocast for {forward_ad_function_name} api. "; '

# Forward layout autotune
layout_autotune_list_str = " ".join(
layout_autotune_list
@@ -2020,6 +2107,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
dygraph_event_str,
amp_logic_str,
type_promotion_logic_str,
type_autocast_logic_str,
layout_logic_str,
forward_api_name,
before_log_str,
@@ -2044,6 +2132,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
dygraph_event_str,
amp_logic_str,
type_promotion_logic_str,
type_autocast_logic_str,
layout_logic_str,
inputs_autograd_meta_str,
forward_api_name,
100 changes: 100 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
@@ -72,6 +72,48 @@
"remainder_": ["x", "y"],
}

# ops support casting int tensor into float32 to do forward calculation
type_autocast_op_list = {
"acos": ["x"],
"acosh": ["x"],
"asin": ["x"],
"asinh": ["x"],
"atan": ["x"],
"atanh": ["x"],
"ceil": ["x"],
"cos": ["x"],
"cosh": ["x"],
"digamma": ["x"],
"erf": ["x"],
"erfinv": ["x"],
"floor": ["x"],
"i0": ["x"],
"i0e": ["x"],
"i1": ["x"],
"i1e": ["x"],
"lgamma": ["x"],
"logcumsumexp": ["x"],
"logit": ["x"],
"logsumexp": ["x"],
"polygamma": ["x"],
"reciprocal": ["x"],
"rsqrt": ["x"],
"sigmoid": ["x"],
"sin": ["x"],
"sinh": ["x"],
"sqrt": ["x"],
"stanh": ["x"],
"tan": ["x"],
"tanh": ["x"],
}

# ops support casting int tensor into float32 to do forward calculation,
# and it is valid to cast float32 gradient back to int tensor.
type_autocast_valid_grad_op_list = {
"ceil",
"floor",
}

PD_MANUAL_API_LIST = {
'embedding_grad',
'assign',
@@ -140,6 +182,8 @@
{amp_logic}
// Type Promotion Logic
{type_promotion_logic}
// Type Autocast Logic
{type_autocast_logic}
{check_data_type}
{handle_optional_inputs}
{in_combine}
@@ -196,6 +240,18 @@
}}
"""

TYPE_AUTOCAST_LOGIC_TEMPLATE = """
auto x_dtype = paddle::imperative::GetDataType({x});
if (phi::NeedTypeAutoCast("{op_name}", x_dtype)) {{
VLOG(5) << "math operation got integer input data type, run type autocast.";
LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed.";
//{op_name}
if (!{trace_backward}) {{ SetStopGradient({x}); }}
auto new_{x} = pir::PromoteCast("{x}", {x}, phi::DataType::FLOAT32);
return paddle::dialect::{op_name}({args});
}}
"""

OP_DISPATCH_TEMPLATE = """
if ({cond}) {{
{inner_code}
@@ -861,6 +917,44 @@ def _gen_type_promotion_logic(self, op_info, op_name):

return type_promotion_logic_str

def _gen_type_autocast_args(self, op_info, op_name):
type_autocast_inputs_call_list = []
for name in op_info.input_name_list:
if op_name in type_autocast_op_list:
if name in type_autocast_op_list[op_name]:
type_autocast_inputs_call_list.append(f"new_{name}")
else:
type_autocast_inputs_call_list.append(f"{name}")

attr_list = op_info.attribute_name_list
args = type_autocast_inputs_call_list + attr_list
return ', '.join(args)

def _gen_type_autocast_logic(self, op_info, op_name):
if op_name in type_autocast_op_list:
x = type_autocast_op_list[op_name][0]

type_autocast_inputs_call_args_str = self._gen_type_autocast_args(
op_info, op_name
)
trace_backward = op_name in type_autocast_valid_grad_op_list
trace_backward = str(trace_backward).lower()

if op_info.is_sparse_op:
op_name += "sp_" if op_name[-1] == "_" else "_sp"
type_autocast_logic_str = TYPE_AUTOCAST_LOGIC_TEMPLATE.format(
op_name=op_name,
x=x,
trace_backward=trace_backward,
args=type_autocast_inputs_call_args_str,
)
else:
type_autocast_logic_str = (
f'\n VLOG(5) << " No Type Autocast for {op_name} api. "; '
)

return type_autocast_logic_str
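
(Not part of the PR's diff.) The sparse-op renaming on the `op_name += ...` line above is easy to misread; a tiny standalone sketch of that rule, with hypothetical op names:

def sparse_op_name(op_name: str) -> str:
    # inplace ops end with "_", so the sparse marker is inserted before that trailing "_"
    return op_name + ("sp_" if op_name[-1] == "_" else "_sp")

assert sparse_op_name("tanh") == "tanh_sp"    # plain op gets "_sp" appended
assert sparse_op_name("tanh_") == "tanh_sp_"  # inplace op keeps its trailing underscore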

def _gen_check_data_type(self, op_info, op_name):
mapping_input_name_to_type = dict(
zip(op_info.input_name_list, op_info.input_type_list)
@@ -1044,6 +1138,9 @@ def _gen_one_impl(
type_promotion_logic=self._gen_type_promotion_logic(
op_info, op_name
),
type_autocast_logic=self._gen_type_autocast_logic(
op_info, op_name
),
check_data_type=self._gen_check_data_type(
op_info, kernel_name
),
@@ -1109,6 +1206,9 @@ def _gen_one_impl(
type_promotion_logic=self._gen_type_promotion_logic(
op_info, op_name
),
type_autocast_logic=self._gen_type_autocast_logic(
op_info, op_name
),
check_data_type=self._gen_check_data_type(op_info, kernel_name),
handle_optional_inputs=self._gen_handle_optional_inputs(
op_info
17 changes: 17 additions & 0 deletions paddle/phi/common/type_promotion.h
@@ -94,6 +94,15 @@ static std::unordered_set<std::string> support_promotion_ops = {
"less_than", "less_equal", "greater_than", "greater_equal",
};

static std::unordered_set<std::string> support_autocast_ops = {
"acos", "acosh", "asin", "asinh", "atan", "atanh",
"ceil", "cos", "cosh", "digamma", "erf", "erfinv",
"floor", "lgamma", "logcumsumexp", "logit", "logsumexp", "polygamma",
"reciprocal", "rsqrt", "sin", "sinh", "sqrt", "stanh",
"tan", "tanh", "i0", "i0e", "i1", "i1e",
"sigmoid",
};

inline bool is_support_float(DataType dtype) {
if (dtype == DataType::FLOAT16 || dtype == DataType::FLOAT32 ||
dtype == DataType::FLOAT64 || dtype == DataType::BFLOAT16) {
@@ -264,4 +273,12 @@ inline bool NeedTypePromotionOldIr(const std::string& op_name,
}
}

inline bool NeedTypeAutoCast(const std::string& op_name,
const DataType& x_dtype) {
if (support_autocast_ops.find(op_name) != support_autocast_ops.end() &&
(is_support_int(x_dtype))) {
return true;
}
return false;
}
} // namespace phi
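
(Not part of the PR's diff.) A Python paraphrase of the NeedTypeAutoCast check above, for illustration only; the set of dtypes accepted by is_support_int is an assumption here and may differ from the actual header:

SUPPORT_AUTOCAST_OPS = {
    "acos", "acosh", "asin", "asinh", "atan", "atanh", "ceil", "cos", "cosh",
    "digamma", "erf", "erfinv", "floor", "i0", "i0e", "i1", "i1e", "lgamma",
    "logcumsumexp", "logit", "logsumexp", "polygamma", "reciprocal", "rsqrt",
    "sigmoid", "sin", "sinh", "sqrt", "stanh", "tan", "tanh",
}
# Assumed integer dtypes; the real is_support_int() may accept additional int types.
INT_DTYPES = {"int32", "int64"}

def need_type_autocast(op_name: str, x_dtype: str) -> bool:
    # cast only when the op is whitelisted AND the input dtype is an integer type
    return op_name in SUPPORT_AUTOCAST_OPS and x_dtype in INT_DTYPES

print(need_type_autocast("sin", "int64"))    # True  -> cast input to float32
print(need_type_autocast("sin", "float32"))  # False -> existing behavior unchanged
print(need_type_autocast("add", "int64"))    # False -> op not in the autocast list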