From 5f70fcdf1666339fa77847aa8115c11713135ff8 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 1 Feb 2024 20:44:56 +0800 Subject: [PATCH] [PIR+CINN]All CINN Subgraph UT Support CheckJitKernelInfo (#61431) * [PIR+CINN]All CINN Subgraph UT Support CheckJitKernelInfo * fix conflict --- .../cinn/adt/test_cinn_sub_graph_map_expr.py | 21 +++++----- .../symbolic/test_cinn_broadcast_symbolic.py | 24 +++++------ .../test_cinn_reduce_symbolic_demo.py | 26 ++++++------ .../symbolic/test_cinn_sub_graph_symbolic.py | 40 +++++++++++-------- .../symbolic/test_sub_graph_for_backend.py | 22 +++++----- .../symbolic/test_sub_graph_for_frontend.py | 26 ++++++------ 6 files changed, 88 insertions(+), 71 deletions(-) diff --git a/test/ir/pir/cinn/adt/test_cinn_sub_graph_map_expr.py b/test/ir/pir/cinn/adt/test_cinn_sub_graph_map_expr.py index 71a2145b5203f..e0fe9d3f8c148 100644 --- a/test/ir/pir/cinn/adt/test_cinn_sub_graph_map_expr.py +++ b/test/ir/pir/cinn/adt/test_cinn_sub_graph_map_expr.py @@ -11,22 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +from os.path import dirname + +sys.path.append(dirname(dirname(__file__))) import unittest import numpy as np +import utils import paddle -def apply_to_static(net, use_cinn): - build_strategy = paddle.static.BuildStrategy() - build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static( - net, build_strategy=build_strategy, full_graph=True - ) - - def exp_sub(x): y = paddle.exp(x) z = y - x @@ -58,12 +55,18 @@ def prepare_data(self): self.x = paddle.randn(self.shape, dtype="float32") self.x.stop_gradient = False + def check_jit_kernel_info(self, static_fn): + utils.check_jit_kernel_number(static_fn, 1) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1}) + def train(self, use_cinn): paddle.seed(2022) net = CINNSubGraphNet() - net = apply_to_static(net, use_cinn) + net = utils.apply_to_static(net, use_cinn) net.eval() out = net(self.x) + if use_cinn: + self.check_jit_kernel_info(net.forward) return out def test_forward(self): diff --git a/test/ir/pir/cinn/symbolic/test_cinn_broadcast_symbolic.py b/test/ir/pir/cinn/symbolic/test_cinn_broadcast_symbolic.py index 4de79ccf097d2..63009a9704d7c 100644 --- a/test/ir/pir/cinn/symbolic/test_cinn_broadcast_symbolic.py +++ b/test/ir/pir/cinn/symbolic/test_cinn_broadcast_symbolic.py @@ -11,26 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +from os.path import dirname + +sys.path.append(dirname(dirname(__file__))) import unittest import numpy as np +import utils import paddle from paddle.static import InputSpec -def apply_to_static(net, use_cinn, input_spec=None): - build_strategy = paddle.static.BuildStrategy() - build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static( - net, - input_spec=input_spec, - build_strategy=build_strategy, - full_graph=True, - ) - - def broadcast_add(x, y): return paddle.exp(x) - y @@ -62,6 +56,10 @@ def prepare_data(self): self.y = paddle.randn(self.y_shape, dtype="float32") self.y.stop_gradient = False + def check_jit_kernel_info(self, static_fn): + utils.check_jit_kernel_number(static_fn, 1) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1}) + def eval_symbolic(self, use_cinn): paddle.seed(2022) net = CINNSubGraphNet() @@ -69,9 +67,11 @@ def eval_symbolic(self, use_cinn): InputSpec(shape=[None, 128], dtype='float32'), InputSpec(shape=[None, 128], dtype='float32'), ] - net = apply_to_static(net, use_cinn, input_spec) + net = utils.apply_to_static(net, use_cinn, input_spec) net.eval() out = net(self.x, self.y) + if use_cinn: + self.check_jit_kernel_info(net.forward) return out def test_eval_symolic(self): diff --git a/test/ir/pir/cinn/symbolic/test_cinn_reduce_symbolic_demo.py b/test/ir/pir/cinn/symbolic/test_cinn_reduce_symbolic_demo.py index cb4f93d9be073..bb2e1c789e22f 100644 --- a/test/ir/pir/cinn/symbolic/test_cinn_reduce_symbolic_demo.py +++ b/test/ir/pir/cinn/symbolic/test_cinn_reduce_symbolic_demo.py @@ -11,24 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +from os.path import dirname + +sys.path.append(dirname(dirname(__file__))) import unittest +import utils + import paddle from paddle.static import InputSpec -def apply_to_static(net, use_cinn, input_spec=None): - build_strategy = paddle.static.BuildStrategy() - build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static( - net, - input_spec=input_spec, - build_strategy=build_strategy, - full_graph=True, - ) - - def reduce_sum(x): return paddle.sum(x, axis=-1) @@ -57,15 +52,22 @@ def prepare_data(self): self.x = paddle.randn(self.x_shape, dtype="float32") self.x.stop_gradient = False + def check_jit_kernel_info(self, static_fn): + utils.check_jit_kernel_number(static_fn, 1) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1}) + def eval_symbolic(self, use_cinn): paddle.seed(2022) net = CINNSubGraphNet() input_spec = [ InputSpec(shape=[None, 128], dtype='float32'), ] - net = apply_to_static(net, use_cinn, input_spec) + net = utils.apply_to_static(net, use_cinn, input_spec) net.eval() out = net(self.x) + if use_cinn: + self.check_jit_kernel_info(net.forward) + return out def test_eval_symolic(self): diff --git a/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py b/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py index d8ba056f1fe44..b5efe5685e29a 100644 --- a/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py +++ b/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py @@ -11,10 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +from os.path import dirname + +sys.path.append(dirname(dirname(__file__))) import unittest import numpy as np +import utils import paddle from paddle.static import InputSpec @@ -32,17 +37,6 @@ def get_sym_shape_str_for_op(net, input_spec, op_name): return all_sym_shape_str -def apply_to_static(net, use_cinn, input_spec=None): - build_strategy = paddle.static.BuildStrategy() - build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static( - net, - input_spec=input_spec, - build_strategy=build_strategy, - full_graph=True, - ) - - def exp_sub(x): y = paddle.exp(x) z = y - x @@ -107,6 +101,10 @@ def prepare_data(self): self.x = paddle.randn(self.shape, dtype="float32") self.x.stop_gradient = False + def check_jit_kernel_info(self, static_fn): + utils.check_jit_kernel_number(static_fn, 1) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1}) + def test_eval_symbolic(self): pass @@ -115,9 +113,11 @@ class TestCinnExpSubGraph(TestCinnSubGraphBase): def eval_symbolic(self, use_cinn): net = CINNSubGraphNet() input_spec = [InputSpec(shape=[None, 128], dtype='float32')] - net = apply_to_static(net, use_cinn, input_spec) + net = utils.apply_to_static(net, use_cinn, input_spec) net.eval() out = net(self.x) + if use_cinn: + self.check_jit_kernel_info(net.forward) return out def test_eval_symbolic(self): @@ -136,9 +136,11 @@ def eval_symbolic(self, use_cinn): paddle.seed(2022) net = CINNReshapeSubGraphNet() input_spec = [InputSpec(shape=[None, 256], dtype='float32')] - net = apply_to_static(net, use_cinn, input_spec) + net = utils.apply_to_static(net, use_cinn, input_spec) net.eval() out = net(self.x) + if use_cinn: + self.check_jit_kernel_info(net.forward) return out def test_eval_symbolic(self): @@ -164,9 +166,11 @@ def eval_symbolic(self, use_cinn): InputSpec(shape=[None, None, None], dtype='float32'), InputSpec(shape=[None, None], dtype='float32'), ] - net = apply_to_static(net, use_cinn, input_spec) + net = utils.apply_to_static(net, use_cinn, input_spec) net.eval() out = net(self.x, self.y) + if use_cinn: + self.check_jit_kernel_info(net.forward) return out def test_eval_symbolic(self): @@ -230,7 +234,7 @@ def eval_symbolic(self, use_cinn): input_spec = [ InputSpec(shape=[None, None, 4096], dtype='float32'), ] - net = apply_to_static(net, use_cinn, input_spec) + net = utils.apply_to_static(net, use_cinn, input_spec) net.eval() sym_shape_str_list = get_sym_shape_str_for_op( @@ -244,6 +248,8 @@ def eval_symbolic(self, use_cinn): ) out = net(self.hidden_states) + if use_cinn: + self.check_jit_kernel_info(net.forward) return out @@ -303,7 +309,7 @@ def eval_symbolic(self, use_cinn): input_spec = [ InputSpec(shape=[None, None, 8, 96], dtype='float32'), ] - net = apply_to_static(net, use_cinn, input_spec) + net = utils.apply_to_static(net, use_cinn, input_spec) net.eval() sym_shape_str_list = get_sym_shape_str_for_op( @@ -317,6 +323,8 @@ def eval_symbolic(self, use_cinn): ) out = net(self.hidden_states) + if use_cinn: + self.check_jit_kernel_info(net.forward) return out def test_eval_symbolic(self): diff --git a/test/ir/pir/cinn/symbolic/test_sub_graph_for_backend.py b/test/ir/pir/cinn/symbolic/test_sub_graph_for_backend.py index fbac99d77c244..52e401568adf4 100644 --- a/test/ir/pir/cinn/symbolic/test_sub_graph_for_backend.py +++ b/test/ir/pir/cinn/symbolic/test_sub_graph_for_backend.py @@ -11,21 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +from os.path import dirname + +sys.path.append(dirname(dirname(__file__))) + import unittest import numpy as np +import utils import paddle -def apply_to_static(net, use_cinn): - build_strategy = paddle.static.BuildStrategy() - build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static( - net, build_strategy=build_strategy, full_graph=True - ) - - def exp_sub(x): y = paddle.exp(x) z = y - x @@ -58,12 +56,18 @@ def prepare_data(self): self.x = paddle.randn(self.shape, dtype="float32") self.x.stop_gradient = False + def check_jit_kernel_info(self, static_fn): + utils.check_jit_kernel_number(static_fn, 1) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1}) + def eval(self, use_cinn): paddle.seed(2022) net = CINNSubGraphNet() - net = apply_to_static(net, use_cinn) + net = utils.apply_to_static(net, use_cinn) net.eval() out = net(self.x) + if use_cinn: + self.check_jit_kernel_info(net.forward) return out def test_eval(self): diff --git a/test/ir/pir/cinn/symbolic/test_sub_graph_for_frontend.py b/test/ir/pir/cinn/symbolic/test_sub_graph_for_frontend.py index 52e1b0e65fb2d..a25b6a4d1d275 100644 --- a/test/ir/pir/cinn/symbolic/test_sub_graph_for_frontend.py +++ b/test/ir/pir/cinn/symbolic/test_sub_graph_for_frontend.py @@ -11,25 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys +from os.path import dirname + +sys.path.append(dirname(dirname(__file__))) + import unittest import numpy as np +import utils import paddle from paddle.static import InputSpec -def apply_to_static(net, use_cinn, input_spec=None): - build_strategy = paddle.static.BuildStrategy() - build_strategy.build_cinn_pass = use_cinn - return paddle.jit.to_static( - net, - input_spec=input_spec, - build_strategy=build_strategy, - full_graph=True, - ) - - def exp_sub(x): y = paddle.exp(x) z = y - x @@ -62,16 +57,21 @@ def prepare_data(self): self.x = paddle.randn(self.shape, dtype="float32") self.x.stop_gradient = False + def check_jit_kernel_info(self, static_fn): + utils.check_jit_kernel_number(static_fn, 1) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1}) + def eval(self, use_cinn): paddle.seed(2022) net = CINNSubGraphNet() input_spec = [ InputSpec(shape=[None, 96], dtype='float32'), ] - if use_cinn: - net = apply_to_static(net, use_cinn, input_spec) + net = utils.apply_to_static(net, use_cinn, input_spec) net.eval() out = net(self.x) + if use_cinn: + self.check_jit_kernel_info(net.forward) return out def test_eval(self):