From 308fc046bbd58fbe73f54c4cd65b798379fc57ca Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 8 Nov 2024 17:39:15 +0800 Subject: [PATCH 1/2] add autocast logic --- .../generator/eager_gen.py | 89 +++++++ .../fluid/pir/dialect/op_generator/api_gen.py | 100 +++++++ paddle/phi/common/type_promotion.h | 17 ++ python/paddle/tensor/math.py | 223 +++++++++++++--- python/paddle/tensor/ops.py | 247 +++++++++++++++--- test/legacy_test/test_tensor_type_autocast.py | 214 +++++++++++++++ 6 files changed, 809 insertions(+), 81 deletions(-) create mode 100644 test/legacy_test/test_tensor_type_autocast.py diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 0a1d5742ac630..a82c0b513235e 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -131,6 +131,48 @@ "remainder_": ["x", "y"], } +# ops support casting int tensor into float32 to do forward calculation +type_autocast_op_list = { + "acos": ["x"], + "acosh": ["x"], + "asin": ["x"], + "asinh": ["x"], + "atan": ["x"], + "atanh": ["x"], + "ceil": ["x"], + "cos": ["x"], + "cosh": ["x"], + "digamma": ["x"], + "erf": ["x"], + "erfinv": ["x"], + "floor": ["x"], + "i0": ["x"], + "i0e": ["x"], + "i1": ["x"], + "i1e": ["x"], + "lgamma": ["x"], + "logcumsumexp": ["x"], + "logit": ["x"], + "logsumexp": ["x"], + "polygamma": ["x"], + "reciprocal": ["x"], + "rsqrt": ["x"], + "sigmoid": ["x"], + "sin": ["x"], + "sinh": ["x"], + "sqrt": ["x"], + "stanh": ["x"], + "tan": ["x"], + "tanh": ["x"], +} + +# ops support casting int tensor into float32 to do forward calculation, +# and it is valid to cast float32 gradient back to int tensor. +type_autocast_valid_grad_op_list = { + "ceil", + "floor", +} + # dict of special api that forward api's output will affect backward api's output # backward api's output usually affected by backward api's input @@ -327,6 +369,8 @@ class {} : public egr::GradNodeBase {{ // AMP Logic {} // Type promotion Logic +{} + // Type autocast Logic {} // Layout autotune {} @@ -404,6 +448,8 @@ class {} : public egr::GradNodeBase {{ // AMP Logic {} // Type promotion Logic +{} + // Type autocast Logic {} // Layout autotune {} @@ -618,6 +664,15 @@ class {} : public egr::GradNodeBase {{ }} """ +TYPE_AUTOCAST_LOGIC_TEMPLATE = """ + if (phi::NeedTypeAutoCast({op_func_name}, {x}.dtype())) {{ + VLOG(5) << "math operation got integer input data type, run type autocast."; + LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed."; + {op_name} + auto new_{x} = egr::PromoteCast("{x}", {x}, phi::DataType::FLOAT32, {trace_backward}); + {return_value} + }} +""" LAYOUT_LOGIC_TEMPLATE = """ if (egr::Controller::Instance().UseLayoutAutoTune()) {{ @@ -1563,6 +1618,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): amp_inputs_call_list = ["" for i in range(num_inputs)] type_promote_inputs_call_list = ["" for i in range(num_inputs)] + type_autocast_inputs_call_list = ["" for i in range(num_inputs)] amp_tensors_vector_list = [] amp_tensors_vector_optional_list = [] amp_autocast_list = [] @@ -1591,6 +1647,11 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): type_promote_inputs_call_list[pos] = f"new_{name}" else: type_promote_inputs_call_list[pos] = f"{name}" + if forward_api_name in type_autocast_op_list: + if name in type_autocast_op_list[forward_api_name]: + type_autocast_inputs_call_list[pos] = f"new_{name}" + else: + type_autocast_inputs_call_list[pos] = f"{name}" if IsPlainTensorType(ttype): if is_optional: if ( @@ -1682,6 +1743,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): inputs_call_list[pos] = name amp_inputs_call_list[pos] = name type_promote_inputs_call_list[pos] = name + type_autocast_inputs_call_list[pos] = name if default_val is not None: inputs_args_declaration_list[pos] = ( f"{atype} {name} = {default_val}" @@ -1971,6 +2033,31 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): ) else: type_promotion_logic_str = f'\n VLOG(5) << " No Type Promotion for {forward_ad_function_name} api. "; ' + + # Forward type autocast logic + if forward_api_name in type_autocast_op_list: + # only support one inputs + op_func_name = f'"{forward_api_name}"' + x = type_autocast_op_list[forward_api_name][0] + type_autocast_inputs_call_args_str = ", ".join( + type_autocast_inputs_call_list + ) + trace_backward = ( + forward_api_name in type_autocast_valid_grad_op_list + ) and (not self.is_forward_only) + trace_backward = str(trace_backward).lower() + type_autocast_call_list = f"return {forward_ad_function_name}({type_autocast_inputs_call_args_str});" + + type_autocast_logic_str = TYPE_AUTOCAST_LOGIC_TEMPLATE.format( + op_func_name=op_func_name, + x=x, + op_name=kernel_trans2_op_name_str, + trace_backward=trace_backward, + return_value=type_autocast_call_list, + ) + else: + type_autocast_logic_str = f'\n VLOG(5) << " No Type Autocast for {forward_ad_function_name} api. "; ' + # Forward layout autotune layout_autotune_list_str = " ".join( layout_autotune_list @@ -2020,6 +2107,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): dygraph_event_str, amp_logic_str, type_promotion_logic_str, + type_autocast_logic_str, layout_logic_str, forward_api_name, before_log_str, @@ -2044,6 +2132,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): dygraph_event_str, amp_logic_str, type_promotion_logic_str, + type_autocast_logic_str, layout_logic_str, inputs_autograd_meta_str, forward_api_name, diff --git a/paddle/fluid/pir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py index 9240e8205e703..f32ccb6d40e93 100644 --- a/paddle/fluid/pir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py @@ -72,6 +72,48 @@ "remainder_": ["x", "y"], } +# ops support casting int tensor into float32 to do forward calculation +type_autocast_op_list = { + "acos": ["x"], + "acosh": ["x"], + "asin": ["x"], + "asinh": ["x"], + "atan": ["x"], + "atanh": ["x"], + "ceil": ["x"], + "cos": ["x"], + "cosh": ["x"], + "digamma": ["x"], + "erf": ["x"], + "erfinv": ["x"], + "floor": ["x"], + "i0": ["x"], + "i0e": ["x"], + "i1": ["x"], + "i1e": ["x"], + "lgamma": ["x"], + "logcumsumexp": ["x"], + "logit": ["x"], + "logsumexp": ["x"], + "polygamma": ["x"], + "reciprocal": ["x"], + "rsqrt": ["x"], + "sigmoid": ["x"], + "sin": ["x"], + "sinh": ["x"], + "sqrt": ["x"], + "stanh": ["x"], + "tan": ["x"], + "tanh": ["x"], +} + +# ops support casting int tensor into float32 to do forward calculation, +# and it is valid to cast float32 gradient back to int tensor. +type_autocast_valid_grad_op_list = { + "ceil", + "floor", +} + PD_MANUAL_API_LIST = { 'embedding_grad', 'assign', @@ -140,6 +182,8 @@ {amp_logic} // Type Promotion Logic {type_promotion_logic} + // Type Autocast Logic + {type_autocast_logic} {check_data_type} {handle_optional_inputs} {in_combine} @@ -196,6 +240,18 @@ }} """ +TYPE_AUTOCAST_LOGIC_TEMPLATE = """ + auto x_dtype = paddle::imperative::GetDataType({x}); + if (phi::NeedTypeAutoCast("{op_name}", x_dtype)) {{ + VLOG(5) << "math operation got integer input data type, run type autocast."; + LOG_FIRST_N(WARNING, 1) << "math operation got integer input data type, run type autocast, this may cause data type been changed."; + //{op_name} + if (!{trace_backward}) {{ SetStopGradient({x}); }} + auto new_{x} = pir::PromoteCast("{x}", {x}, phi::DataType::FLOAT32); + return paddle::dialect::{op_name}({args}); + }} +""" + OP_DISPATCH_TEMPLATE = """ if ({cond}) {{ {inner_code} @@ -861,6 +917,44 @@ def _gen_type_promotion_logic(self, op_info, op_name): return type_promotion_logic_str + def _gen_type_autocast_args(self, op_info, op_name): + type_autocast_inputs_call_list = [] + for name in op_info.input_name_list: + if op_name in type_autocast_op_list: + if name in type_autocast_op_list[op_name]: + type_autocast_inputs_call_list.append(f"new_{name}") + else: + type_autocast_inputs_call_list.append(f"{name}") + + attr_list = op_info.attribute_name_list + args = type_autocast_inputs_call_list + attr_list + return ', '.join(args) + + def _gen_type_autocast_logic(self, op_info, op_name): + if op_name in type_autocast_op_list: + x = type_autocast_op_list[op_name][0] + + type_autocast_inputs_call_args_str = self._gen_type_autocast_args( + op_info, op_name + ) + trace_backward = op_name in type_autocast_valid_grad_op_list + trace_backward = str(trace_backward).lower() + + if op_info.is_sparse_op: + op_name += "sp_" if op_name[-1] == "_" else "_sp" + type_autocast_logic_str = TYPE_AUTOCAST_LOGIC_TEMPLATE.format( + op_name=op_name, + x=x, + trace_backward=trace_backward, + args=type_autocast_inputs_call_args_str, + ) + else: + type_autocast_logic_str = ( + f'\n VLOG(5) << " No Type Autocast for {op_name} api. "; ' + ) + + return type_autocast_logic_str + def _gen_check_data_type(self, op_info, op_name): mapping_input_name_to_type = dict( zip(op_info.input_name_list, op_info.input_type_list) @@ -1044,6 +1138,9 @@ def _gen_one_impl( type_promotion_logic=self._gen_type_promotion_logic( op_info, op_name ), + type_autocast_logic=self._gen_type_autocast_logic( + op_info, op_name + ), check_data_type=self._gen_check_data_type( op_info, kernel_name ), @@ -1109,6 +1206,9 @@ def _gen_one_impl( type_promotion_logic=self._gen_type_promotion_logic( op_info, op_name ), + type_autocast_logic=self._gen_type_autocast_logic( + op_info, op_name + ), check_data_type=self._gen_check_data_type(op_info, kernel_name), handle_optional_inputs=self._gen_handle_optional_inputs( op_info diff --git a/paddle/phi/common/type_promotion.h b/paddle/phi/common/type_promotion.h index 95e3ae2933312..707500b93d52c 100644 --- a/paddle/phi/common/type_promotion.h +++ b/paddle/phi/common/type_promotion.h @@ -94,6 +94,15 @@ static std::unordered_set support_promotion_ops = { "less_than", "less_equal", "greater_than", "greater_equal", }; +static std::unordered_set support_autocast_ops = { + "acos", "acosh", "asin", "asinh", "atan", "atanh", + "ceil", "cos", "cosh", "digamma", "erf", "erfinv", + "floor", "lgamma", "logcumsumexp", "logit", "logsumexp", "polygamma", + "reciprocal", "rsqrt", "sin", "sinh", "sqrt", "stanh", + "tan", "tanh", "i0", "i0e", "i1", "i1e", + "sigmoid", +}; + inline bool is_support_float(DataType dtype) { if (dtype == DataType::FLOAT16 || dtype == DataType::FLOAT32 || dtype == DataType::FLOAT64 || dtype == DataType::BFLOAT16) { @@ -264,4 +273,12 @@ inline bool NeedTypePromotionOldIr(const std::string& op_name, } } +inline bool NeedTypeAutoCast(const std::string& op_name, + const DataType& x_dtype) { + if (support_autocast_ops.find(op_name) != support_autocast_ops.end() && + (is_support_int(x_dtype))) { + return true; + } + return false; +} } // namespace phi diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index cd55d1372a42d..4cba41372e833 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -350,13 +350,14 @@ def stanh( out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}} Parameters: - x (Tensor): The input Tensor with data type float32, float64. + x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64, + uint8, int8, int16, int32, int64. scale_a (float, optional): The scale factor a of the input. Default is 0.67. scale_b (float, optional): The scale factor b of the output. Default is 1.7159. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - A Tensor with the same data type and shape as ``x`` . + A Tensor with the same shape and data type as ``x`` (integer types are autocasted into float32). Examples: .. code-block:: python @@ -375,7 +376,20 @@ def stanh( return _C_ops.stanh(x, scale_a, scale_b) else: check_variable_and_dtype( - x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'stanh' + x, + 'x', + [ + 'float16', + 'uint16', + 'float32', + 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + 'stanh', ) helper = LayerHelper('stanh', **locals()) @@ -2760,8 +2774,9 @@ def logsumexp( logsumexp(x) = \log\sum exp(x) Args: - x (Tensor): The input Tensor with data type float16, float32 or float64, which - have no more than 4 dimensions. + x (Tensor): The input Tensor with data type bfloat16, float16, float32, + float64, uint8, int8, int16, int32, int64, which have no more than + 4 dimensions. axis (int|list|tuple|None, optional): The axis along which to perform logsumexp calculations. ``axis`` should be int, list(int) or tuple(int). If ``axis`` is a list/tuple of dimension(s), logsumexp @@ -2781,7 +2796,7 @@ def logsumexp( Returns: Tensor, results of logsumexp along ``axis`` of ``x``, with the same data - type as ``x``. + type as ``x`` (integer types are autocasted into float32). Examples: @@ -2806,7 +2821,20 @@ def logsumexp( return _C_ops.logsumexp(x, axis, keepdim, reduce_all) else: check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'logsumexp' + x, + 'x', + [ + 'float16', + 'float32', + 'float64', + 'uint16', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + 'logsumexp', ) helper = LayerHelper('logsumexp', **locals()) @@ -4463,13 +4491,14 @@ def logcumsumexp( The first element of the result is the same as the first element of the input. Args: - x (Tensor): The input tensor. + x (Tensor): The input tensor, with data type float32, float64, float16, + bfloat16, uint8, int8, int16, int32, int64 axis (int, optional): The dimension to do the operation along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array. dtype (str|paddle.dtype|np.dtype, optional): The data type of the output tensor, can be float16, float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor, the result of logcumsumexp operator. + Tensor, the result of logcumsumexp operator (integer input types are autocasted into float32). Examples: .. code-block:: python @@ -4516,7 +4545,20 @@ def logcumsumexp( return _C_ops.logcumsumexp(x, axis, flatten, False, False) else: check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'uint16'], "logcumsumexp" + x, + 'x', + [ + 'float16', + 'float32', + 'float64', + 'uint16', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + "logcumsumexp", ) helper = LayerHelper('logcumsumexp', **locals()) @@ -4949,11 +4991,13 @@ def tanh(x: Tensor, name: str | None = None) -> Tensor: out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}} Args: - x (Tensor): Input of Tanh operator, an N-D Tensor, with data type bfloat16, float32, float64 or float16. + x (Tensor): Input of Tanh operator, an N-D Tensor, with data type bfloat16, float32, float64, + float16, uint8, int8, int16, int32, int64. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Output of Tanh operator, a Tensor with same data type and shape as input. + Output of Tanh operator, a Tensor with same data type and shape as input + (integer types are autocasted into float32). Examples: @@ -4971,7 +5015,20 @@ def tanh(x: Tensor, name: str | None = None) -> Tensor: return _C_ops.tanh(x) else: check_variable_and_dtype( - x, 'x', ['uint16', 'float16', 'float32', 'float64'], 'tanh' + x, + 'x', + [ + 'uint16', + 'float16', + 'float32', + 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + 'tanh', ) check_type(x, 'x', (Variable), 'tanh') helper = LayerHelper('tanh', **locals()) @@ -5364,10 +5421,12 @@ def digamma(x: Tensor, name: str | None = None) -> Tensor: Out = \Psi(x) = \frac{ \Gamma^{'}(x) }{ \Gamma(x) } Args: - x (Tensor): Input Tensor. Must be one of the following types: float32, float64. + x (Tensor): Input Tensor. Must be one of the following types: bfloat16, float16, float32, + float64, uint8, int8, int16, int32, int64. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor, the digamma of the input Tensor, the shape and data type is the same with input. + Tensor, the digamma of the input Tensor, the shape and data type is the same with input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -5386,7 +5445,20 @@ def digamma(x: Tensor, name: str | None = None) -> Tensor: return _C_ops.digamma(x) else: check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'digamma' + x, + 'x', + [ + 'float16', + 'float32', + 'float64', + 'uint16', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + 'digamma', ) helper = LayerHelper('digamma', **locals()) out = helper.create_variable_for_type_inference(x.dtype) @@ -5516,11 +5588,13 @@ def lgamma(x: Tensor, name: str | None = None) -> Tensor: Args: - x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, uint16. + x (Tensor): Input Tensor. Must be one of the following types: bfloat16, float16, float32, float64, + uint8, int8, int16, int32, int64. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor, the lgamma of the input Tensor, the shape and data type is the same with input. + Tensor, the lgamma of the input Tensor, the shape and data type is the same with input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -5736,13 +5810,15 @@ def logit( \end{array}\right. Args: - x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64. + x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64, + uint8, int8, int16, int32, int64. eps (float|None, optional): the epsilon for input clamp bound. Default is None. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - out(Tensor): A Tensor with the same data type and shape as ``x`` . + out(Tensor): A Tensor with the same data type and shape as ``x`` + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -5762,7 +5838,20 @@ def logit( return _C_ops.logit(x, eps) else: check_variable_and_dtype( - x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'logit' + x, + 'x', + [ + 'float16', + 'uint16', + 'float32', + 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + 'logit', ) helper = LayerHelper("logit", **locals()) out = helper.create_variable_for_type_inference(x.dtype) @@ -5879,11 +5968,13 @@ def erfinv(x: Tensor, name: str | None = None) -> Tensor: erfinv(erf(x)) = x. Args: - x (Tensor): An N-D Tensor, the data type is float16, bfloat16, float32, float64. + x (Tensor): An N-D Tensor, the data type is float16, bfloat16, float32, float64, + uint8, int8, int16, int32, int64. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - out (Tensor), an N-D Tensor, the shape and data type is the same with input. + out (Tensor), an N-D Tensor, the shape and data type is the same with input + (integer types are autocasted into float32). Example: .. code-block:: python @@ -5901,7 +5992,20 @@ def erfinv(x: Tensor, name: str | None = None) -> Tensor: return _C_ops.erfinv(x) else: check_variable_and_dtype( - x, 'x', ['float32', 'float64', 'float16', 'uint16'], 'erfinv' + x, + 'x', + [ + 'float32', + 'float64', + 'float16', + 'uint16', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + 'erfinv', ) helper = LayerHelper('erfinv', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -7267,11 +7371,13 @@ def i0(x: Tensor, name: str | None = None) -> Tensor: I_0(x) = \sum^{\infty}_{k=0}\frac{(x^2/4)^k}{(k!)^2} Args: - x (Tensor): The input tensor, it's data type should be float32, float64. + x (Tensor): The input tensor, it's data type should be float32, float64, + uint8, int8, int16, int32, int64. name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: - - out (Tensor), A Tensor. the value of the modified bessel function of order 0 at x. + - out (Tensor), A Tensor. the value of the modified bessel function of order 0 at x + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -7286,7 +7392,12 @@ def i0(x: Tensor, name: str | None = None) -> Tensor: if in_dynamic_or_pir_mode(): return _C_ops.i0(x) else: - check_variable_and_dtype(x, "x", ["float32", "float64"], "i0") + check_variable_and_dtype( + x, + "x", + ["float32", "float64", "uint8", "int8", "int16", "int32", "int64"], + "i0", + ) helper = LayerHelper("i0", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -7316,11 +7427,13 @@ def i0e(x: Tensor, name: str | None = None) -> Tensor: I_{0e}(x) = e^{-|x|}I_0(x) Args: - x (Tensor): The input tensor, it's data type should be float32, float64. + x (Tensor): The input tensor, it's data type should be float32, float64, + uint8, int8, int16, int32, int64. name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: - - out (Tensor), A Tensor. the value of the exponentially scaled modified Bessel function of order 0 at x. + - out (Tensor), A Tensor. the value of the exponentially scaled modified Bessel function of order 0 at x + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -7335,7 +7448,12 @@ def i0e(x: Tensor, name: str | None = None) -> Tensor: if in_dynamic_or_pir_mode(): return _C_ops.i0e(x) else: - check_variable_and_dtype(x, "x", ["float32", "float64"], "i0e") + check_variable_and_dtype( + x, + "x", + ["float32", "float64", "uint8", "int8", "int16", "int32", "int64"], + "i0e", + ) helper = LayerHelper("i0e", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -7348,11 +7466,13 @@ def i1(x: Tensor, name: str | None = None) -> Tensor: The function is used to calculate modified bessel function of order 1. Args: - x (Tensor): The input tensor, it's data type should be float32, float64. + x (Tensor): The input tensor, it's data type should be float32, float64, + uint8, int8, int16, int32, int64. name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: - - out (Tensor), A Tensor. the value of the modified bessel function of order 1 at x. + - out (Tensor), A Tensor. the value of the modified bessel function of order 1 at x + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -7367,7 +7487,12 @@ def i1(x: Tensor, name: str | None = None) -> Tensor: if in_dynamic_or_pir_mode(): return _C_ops.i1(x) else: - check_variable_and_dtype(x, "x", ["float32", "float64"], "i1") + check_variable_and_dtype( + x, + "x", + ["float32", "float64", "uint8", "int8", "int16", "int32", "int64"], + "i1", + ) helper = LayerHelper("i1", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -7383,11 +7508,13 @@ def i1e(x: Tensor, name: str | None = None) -> Tensor: Args: - x (Tensor): The input tensor, it's data type should be float32, float64. + x (Tensor): The input tensor, it's data type should be float32, float64, + uint8, int8, int16, int32, int64. name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: - - out (Tensor), A Tensor. the value of the exponentially scaled modified Bessel function of order 1 at x. + - out (Tensor), A Tensor. the value of the exponentially scaled modified Bessel function of order 1 at x + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -7402,7 +7529,12 @@ def i1e(x: Tensor, name: str | None = None) -> Tensor: if in_dynamic_or_pir_mode(): return _C_ops.i1e(x) else: - check_variable_and_dtype(x, "x", ["float32", "float64"], "i1e") + check_variable_and_dtype( + x, + "x", + ["float32", "float64", "uint8", "int8", "int16", "int32", "int64"], + "i1e", + ) helper = LayerHelper("i1e", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -7422,12 +7554,14 @@ def polygamma(x: Tensor, n: int, name: str | None = None) -> Tensor: \Phi^n(x) = \frac{d^n}{dx^n} [\ln(\Gamma(x))] Args: - x (Tensor): Input Tensor. Must be one of the following types: float32, float64. + x (Tensor): Input Tensor. Must be one of the following types: float32, float64, + uint8, int8, int16, int32, int64. n (int): Order of the derivative. Must be integral. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - - out (Tensor), A Tensor. the polygamma of the input Tensor, the shape and data type is the same with input. + - out (Tensor), A Tensor. the polygamma of the input Tensor, the shape and data type is the same with input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -7455,7 +7589,18 @@ def polygamma(x: Tensor, n: int, name: str | None = None) -> Tensor: return _C_ops.polygamma(x, n) else: check_variable_and_dtype( - x, "x", ["float32", "float64"], "polygamma" + x, + "x", + [ + "float32", + "float64", + "uint8", + "int8", + "int16", + "int32", + "int64", + ], + "polygamma", ) helper = LayerHelper("polygamma", **locals()) diff --git a/python/paddle/tensor/ops.py b/python/paddle/tensor/ops.py index 5d5dc90f943c4..90fbeb7a2e719 100644 --- a/python/paddle/tensor/ops.py +++ b/python/paddle/tensor/ops.py @@ -105,11 +105,13 @@ def acos(x: Tensor, name: str | None = None) -> Tensor: out = cos^{-1}(x) Args: - x (Tensor): Input of Acos operator, an N-D Tensor, with data type float32, float64, float16, complex64 or complex128. + x (Tensor): Input of Acos operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Acos operator, a Tensor with shape same as input. + Tensor. Output of Acos operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -133,6 +135,11 @@ def acos(x: Tensor, name: str | None = None) -> Tensor: 'uint16', 'float32', 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -152,11 +159,13 @@ def acosh(x: Tensor, name: str | None = None) -> Tensor: out = acosh(x) Args: - x (Tensor): Input of Acosh operator, an N-D Tensor, with data type float32, float64, float16, complex64 or complex128. + x (Tensor): Input of Acosh operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Acosh operator, a Tensor with shape same as input. + Tensor. Output of Acosh operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -180,6 +189,11 @@ def acosh(x: Tensor, name: str | None = None) -> Tensor: 'uint16', 'float32', 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -199,11 +213,13 @@ def asin(x: Tensor, name: str | None = None) -> Tensor: out = sin^{-1}(x) Args: - x (Tensor): Input of Asin operator, an N-D Tensor, with data type float32, float64, float16, complex64 or complex128. + x (Tensor): Input of Asin operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Same shape and dtype as input. + Tensor. Same shape and dtype as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -227,6 +243,11 @@ def asin(x: Tensor, name: str | None = None) -> Tensor: 'uint16', 'float32', 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -246,11 +267,13 @@ def asinh(x: Tensor, name: str | None = None) -> Tensor: out = asinh(x) Args: - x (Tensor): Input of Asinh operator, an N-D Tensor, with data type float32, float64, float16, complex64 or complex128. + x (Tensor): Input of Asinh operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Asinh operator, a Tensor with shape same as input. + Tensor. Output of Asinh operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -274,6 +297,11 @@ def asinh(x: Tensor, name: str | None = None) -> Tensor: 'uint16', 'float32', 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -293,11 +321,13 @@ def atan(x: Tensor, name: str | None = None) -> Tensor: out = tan^{-1}(x) Args: - x (Tensor): Input of Atan operator, an N-D Tensor, with data type float32, float64, float16, complex64 or complex128. + x (Tensor): Input of Atan operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Same shape and dtype as input x. + Tensor. Same shape and dtype as input x + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -321,6 +351,11 @@ def atan(x: Tensor, name: str | None = None) -> Tensor: 'uint16', 'float32', 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -340,11 +375,13 @@ def atanh(x: Tensor, name: str | None = None) -> Tensor: out = atanh(x) Args: - x (Tensor): Input of Atan operator, an N-D Tensor, with data type float32, float64, float16, complex64 or complex128. + x (Tensor): Input of Atan operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Atanh operator, a Tensor with shape same as input. + Tensor. Output of Atanh operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -368,6 +405,11 @@ def atanh(x: Tensor, name: str | None = None) -> Tensor: 'uint16', 'float32', 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -388,11 +430,13 @@ def ceil(x: Tensor, name: str | None = None) -> Tensor: out = \\left \\lceil x \\right \\rceil Args: - x (Tensor): Input of Ceil operator, an N-D Tensor, with data type float32, float64 or float16. + x (Tensor): Input of Ceil operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Ceil operator, a Tensor with shape same as input. + Tensor. Output of Ceil operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -409,7 +453,20 @@ def ceil(x: Tensor, name: str | None = None) -> Tensor: return _C_ops.ceil(x) else: check_variable_and_dtype( - x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'ceil' + x, + 'x', + [ + 'float16', + 'uint16', + 'float32', + 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + 'ceil', ) helper = LayerHelper('ceil', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -427,11 +484,13 @@ def cos(x: Tensor, name: str | None = None) -> Tensor: out = cos(x) Args: - x (Tensor): Input of Cos operator, an N-D Tensor, with data type float32, float64 or float16. + x (Tensor): Input of Cos operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64, complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Cos operator, a Tensor with shape same as input. + Tensor. Output of Cos operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -450,7 +509,19 @@ def cos(x: Tensor, name: str | None = None) -> Tensor: check_variable_and_dtype( x, 'x', - ['float16', 'float32', 'float64', 'complex64', 'complex128'], + [ + 'float16', + 'uint16', + 'float32', + 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + 'complex64', + 'complex128', + ], 'cos', ) helper = LayerHelper('cos', **locals()) @@ -469,11 +540,13 @@ def cosh(x: Tensor, name: str | None = None) -> Tensor: out = \\frac{exp(x)+exp(-x)}{2} Args: - x (Tensor): Input of Cosh operator, an N-D Tensor, with data type float32, float64, float16, complex64 or complex128. + x (Tensor): Input of Cosh operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Cosh operator, a Tensor with shape same as input. + Tensor. Output of Cosh operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -497,6 +570,11 @@ def cosh(x: Tensor, name: str | None = None) -> Tensor: 'uint16', 'float32', 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -617,11 +695,13 @@ def floor(x: Tensor, name: str | None = None) -> Tensor: out = \\lfloor x \\rfloor Args: - x (Tensor): Input of Floor operator, an N-D Tensor, with data type float32, float64 or float16. + x (Tensor): Input of Floor operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Floor operator, a Tensor with shape same as input. + Tensor. Output of Floor operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -638,7 +718,20 @@ def floor(x: Tensor, name: str | None = None) -> Tensor: return _C_ops.floor(x) else: check_variable_and_dtype( - x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'floor' + x, + 'x', + [ + 'float16', + 'uint16', + 'float32', + 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + 'floor', ) helper = LayerHelper('floor', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -655,11 +748,13 @@ def reciprocal(x: Tensor, name: str | None = None) -> Tensor: out = \\frac{1}{x} Args: - x (Tensor): Input of Reciprocal operator, an N-D Tensor, with data type float32, float64, float16, complex64 or complex128. + x (Tensor): Input of Reciprocal operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Reciprocal operator, a Tensor with shape same as input. + Tensor. Output of Reciprocal operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -676,7 +771,20 @@ def reciprocal(x: Tensor, name: str | None = None) -> Tensor: return _C_ops.reciprocal(x) else: check_variable_and_dtype( - x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'reciprocal' + x, + 'x', + [ + 'float16', + 'uint16', + 'float32', + 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + 'reciprocal', ) helper = LayerHelper('reciprocal', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -756,11 +864,13 @@ def rsqrt(x: Tensor, name: str | None = None) -> Tensor: out = \\frac{1}{\\sqrt{x}} Args: - x (Tensor): Input of Rsqrt operator, an N-D Tensor, with data type float32, float64 or float16. + x (Tensor): Input of Rsqrt operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Rsqrt operator, a Tensor with shape same as input. + Tensor. Output of Rsqrt operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -777,7 +887,20 @@ def rsqrt(x: Tensor, name: str | None = None) -> Tensor: return _C_ops.rsqrt(x) else: check_variable_and_dtype( - x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'rsqrt' + x, + 'x', + [ + 'float16', + 'uint16', + 'float32', + 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], + 'rsqrt', ) helper = LayerHelper('rsqrt', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -793,11 +916,13 @@ def sigmoid(x: Tensor, name: str | None = None) -> Tensor: out = \\frac{1}{1 + e^{-x}} Args: - x (Tensor): Input of Sigmoid operator, an N-D Tensor, with data type float16, float32, float64, complex64 or complex128. + x (Tensor): Input of Sigmoid operator, an N-D Tensor, with data type bfloat16, float16, float32, float64, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Sigmoid operator, a Tensor with shape same as input. + Tensor. Output of Sigmoid operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -822,6 +947,11 @@ def sigmoid(x: Tensor, name: str | None = None) -> Tensor: 'float32', 'float64', 'uint16', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -841,11 +971,13 @@ def sin(x: Tensor, name: str | None = None) -> Tensor: out = sin(x) Args: - x (Tensor): Input of Sin operator, an N-D Tensor, with data type float32, float64 or float16. + x (Tensor): Input of Sin operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Sin operator, a Tensor with shape same as input. + Tensor. Output of Sin operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -869,6 +1001,11 @@ def sin(x: Tensor, name: str | None = None) -> Tensor: 'uint16', 'float32', 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -888,11 +1025,13 @@ def sinh(x: Tensor, name: str | None = None) -> Tensor: out = sinh(x) Args: - x (Tensor): Input of Sinh operator, an N-D Tensor, with data type float32, float64, float16, complex64 or complex128. + x (Tensor): Input of Sinh operator, an N-D Tensor, with data type float32, float64, float16, bfloat16, + uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Sinh operator, a Tensor with shape same as input. + Tensor. Output of Sinh operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -916,6 +1055,11 @@ def sinh(x: Tensor, name: str | None = None) -> Tensor: 'uint16', 'float32', 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -935,11 +1079,13 @@ def sqrt(x: Tensor, name: str | None = None) -> Tensor: out=\\sqrt{x}=x^{1/2} Args: - x (Tensor): Input of Sqrt operator, an N-D Tensor, with data type float32, float64 or float16. + x (Tensor): Input of Sqrt operator, an N-D Tensor, with data type float32, float64, float16, bfloat16 + uint8, int8, int16, int32, int64. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Sqrt operator, a Tensor with shape same as input. + Tensor. Output of Sqrt operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -958,7 +1104,17 @@ def sqrt(x: Tensor, name: str | None = None) -> Tensor: check_variable_and_dtype( x, 'x', - ['float16', 'uint16', 'float32', 'float64'], + [ + 'float16', + 'uint16', + 'float32', + 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + ], 'sqrt', ) helper = LayerHelper('sqrt', **locals()) @@ -1025,11 +1181,13 @@ def tan(x: Tensor, name: str | None = None) -> Tensor: out = tan(x) Args: - x (Tensor): Input of Tan operator, an N-D Tensor, with data type float32, float64 or float16. + x (Tensor): Input of Tan operator, an N-D Tensor, with data type float32, float64, float16, + bfloat16, uint8, int8, int16, int32, int64, complex64 or complex128. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor. Output of Tan operator, a Tensor with shape same as input. + Tensor. Output of Tan operator, a Tensor with shape same as input + (integer types are autocasted into float32). Examples: .. code-block:: python @@ -1053,6 +1211,11 @@ def tan(x: Tensor, name: str | None = None) -> Tensor: 'uint16', 'float32', 'float64', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', 'complex64', 'complex128', ], @@ -1074,11 +1237,11 @@ def erf(x: Tensor, name: str | None = None) -> Tensor: out = \frac{2}{\sqrt{\pi}} \int_{0}^{x}e^{- \eta^{2}}d\eta Args: - x (Tensor): The input tensor, it's data type should be float32, float64. + x (Tensor): The input tensor, it's data type should be float32, float64, uint8, int8, int16, int32, int64. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor: The output of Erf, dtype: float32 or float64, the same as the input, shape: the same as the input. + Tensor. The output of Erf, dtype: float32 or float64 (integer types are autocasted into float32), shape: the same as the input. Examples: diff --git a/test/legacy_test/test_tensor_type_autocast.py b/test/legacy_test/test_tensor_type_autocast.py new file mode 100644 index 0000000000000..865fc590bc159 --- /dev/null +++ b/test/legacy_test/test_tensor_type_autocast.py @@ -0,0 +1,214 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle + + +class TestAutocastBase(unittest.TestCase): + def setUp(self): + self.set_api_and_dtypes() + self.places = [paddle.CPUPlace()] + if paddle.core.is_compiled_with_cuda(): + self.places.append(paddle.CUDAPlace(0)) + + def set_api_and_dtypes(self): + pass + + +def create_test_case( + baseclass, + api, + support_int_types=["uint8", "int8", "int16", "int32", "int64"], + **kwargs, +): + class TestAutocast(baseclass): + def set_api_and_dtypes(self): + self.support_int_types = support_int_types + self.api = api + + def test_dygraph(self): + for place in self.places: + paddle.disable_static(place) + for type in self.support_int_types: + x = paddle.arange(-100, 100).astype(type) + x_float = x.astype("float32") + int_out = self.api(x, **kwargs) + float_out = self.api(x_float, **kwargs) + np.testing.assert_array_equal( + int_out.numpy(), float_out.numpy() + ) + + def test_static(self): + paddle.enable_static() + for place in self.places: + exe = paddle.static.Executor(place) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + for type in self.support_int_types: + with paddle.static.program_guard( + main_program, startup_program + ): + x = paddle.arange(-100, 100).astype(type) + x_float = x.astype("float32") + int_out = self.api(x, **kwargs) + float_out = self.api(x_float, **kwargs) + out = exe.run(fetch_list=[int_out, float_out]) + np.testing.assert_array_equal(out[0], out[1]) + paddle.disable_static(place) + + api_name = api.__name__ + cls_name = f"{baseclass.__name__}{api_name}" + TestAutocast.__name__ = cls_name + globals()[cls_name] = TestAutocast + + +def create_test_case_with_grad( + baseclass, + api, + support_int_types=["uint8", "int8", "int16", "int32", "int64"], + **kwargs, +): + class TestAutocastValidGrad(baseclass): + def set_api_and_dtypes(self): + self.support_int_types = support_int_types + self.api = api + + def test_dygraph(self): + for place in self.places: + paddle.disable_static(place) + for type in self.support_int_types: + x = paddle.arange(-100, 100).astype(type) + x_float = x.astype("float32") + x.stop_gradient = False + x_float.stop_gradient = False + int_out = self.api(x, **kwargs) + float_out = self.api(x_float, **kwargs) + int_out.backward() + float_out.backward() + np.testing.assert_array_equal( + int_out.numpy(), float_out.numpy() + ) + np.testing.assert_equal(x.grad.dtype, x.dtype) + np.testing.assert_allclose( + x.grad.numpy(), x_float.grad.numpy() + ) + + def test_static(self): + paddle.enable_static() + for place in self.places: + exe = paddle.static.Executor(place) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + for type in self.support_int_types: + with paddle.static.program_guard( + main_program, startup_program + ): + x = paddle.arange(-100, 100).astype(type) + x_float = x.astype("float32") + x.stop_gradient = False + x_float.stop_gradient = False + int_out = self.api(x, **kwargs) + float_out = self.api(x_float, **kwargs) + x_grad = paddle.static.gradients(int_out, x) + x_float_grad = paddle.static.gradients( + float_out, x_float + ) + out = exe.run( + fetch_list=[ + int_out, + float_out, + x_grad, + x_float_grad, + ] + ) + np.testing.assert_array_equal(out[0], out[1]) + np.testing.assert_equal(out[2].dtype, np.dtype(type)) + np.testing.assert_allclose(out[2], out[3]) + paddle.disable_static(place) + + api_name = api.__name__ + cls_name = f"{baseclass.__name__}{api_name}" + TestAutocastValidGrad.__name__ = cls_name + globals()[cls_name] = TestAutocastValidGrad + + +create_test_case(TestAutocastBase, paddle.acos) + +create_test_case(TestAutocastBase, paddle.acosh) + +create_test_case(TestAutocastBase, paddle.asin) + +create_test_case(TestAutocastBase, paddle.asinh) + +create_test_case(TestAutocastBase, paddle.atan) + +create_test_case(TestAutocastBase, paddle.atanh) + +create_test_case(TestAutocastBase, paddle.cos) + +create_test_case(TestAutocastBase, paddle.cosh) + +create_test_case(TestAutocastBase, paddle.digamma) + +create_test_case(TestAutocastBase, paddle.erf) + +create_test_case(TestAutocastBase, paddle.erfinv) + +create_test_case(TestAutocastBase, paddle.i0) + +create_test_case(TestAutocastBase, paddle.i0e) + +create_test_case(TestAutocastBase, paddle.i1) + +create_test_case(TestAutocastBase, paddle.i1e) + +create_test_case(TestAutocastBase, paddle.lgamma) + +create_test_case(TestAutocastBase, paddle.logcumsumexp) + +create_test_case(TestAutocastBase, paddle.logit) + +create_test_case(TestAutocastBase, paddle.logsumexp) + +create_test_case(TestAutocastBase, paddle.polygamma, n=1) + +create_test_case(TestAutocastBase, paddle.reciprocal) + +create_test_case(TestAutocastBase, paddle.rsqrt) + +create_test_case(TestAutocastBase, paddle.sin) + +create_test_case(TestAutocastBase, paddle.sinh) + +create_test_case(TestAutocastBase, paddle.nn.functional.sigmoid) + +create_test_case(TestAutocastBase, paddle.sqrt) + +create_test_case(TestAutocastBase, paddle.stanh) + +create_test_case(TestAutocastBase, paddle.tan) + +create_test_case(TestAutocastBase, paddle.tanh) + +create_test_case_with_grad(TestAutocastBase, paddle.ceil) + +create_test_case_with_grad(TestAutocastBase, paddle.floor) + + +if __name__ == '__main__': + unittest.main() From 73874b28bd33d9f8f607cbca7ab81acf44b2632a Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 8 Nov 2024 20:16:58 +0800 Subject: [PATCH 2/2] update test --- test/legacy_test/test_digamma_op.py | 4 ++-- test/legacy_test/test_logit_op.py | 2 +- test/legacy_test/test_logsumexp.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/legacy_test/test_digamma_op.py b/test/legacy_test/test_digamma_op.py index f39f887a7f385..367570c6cdfd1 100644 --- a/test/legacy_test/test_digamma_op.py +++ b/test/legacy_test/test_digamma_op.py @@ -146,13 +146,13 @@ def test_dtype_error(self): # in static graph mode with self.assertRaises(TypeError): with static.program_guard(static.Program()): - x = static.data(name="x", shape=self._shape, dtype="int32") + x = static.data(name="x", shape=self._shape, dtype="bool") out = paddle.digamma(x, name="digamma_res") # in dynamic mode with self.assertRaises(RuntimeError): with base.dygraph.guard(): - input = np.random.random(self._shape).astype("int32") + input = np.random.random(self._shape).astype("bool") input_t = paddle.to_tensor(input) res = paddle.digamma(input_t) diff --git a/test/legacy_test/test_logit_op.py b/test/legacy_test/test_logit_op.py index c47dcacbf9cad..5ab56c831c8c2 100644 --- a/test/legacy_test/test_logit_op.py +++ b/test/legacy_test/test_logit_op.py @@ -185,7 +185,7 @@ def test_check_api(self): def test_errors(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): - x = paddle.static.data(name='X1', shape=[100], dtype='int32') + x = paddle.static.data(name='X1', shape=[100], dtype='bool') self.assertRaises(TypeError, paddle.logit, x) x = paddle.static.data(name='X2', shape=[100], dtype='float32') diff --git a/test/legacy_test/test_logsumexp.py b/test/legacy_test/test_logsumexp.py index 714c086a6d8ba..1a7455c49a140 100644 --- a/test/legacy_test/test_logsumexp.py +++ b/test/legacy_test/test_logsumexp.py @@ -249,7 +249,7 @@ class TestLogsumexpError(unittest.TestCase): def test_errors(self): with paddle.static.program_guard(paddle.static.Program()): self.assertRaises(TypeError, paddle.logsumexp, 1) - x1 = paddle.static.data(name='x1', shape=[120], dtype="int32") + x1 = paddle.static.data(name='x1', shape=[120], dtype="bool") self.assertRaises(TypeError, paddle.logsumexp, x1)