Skip to content

Commit

Permalink
[PIR+CINN]All CINN Subgraph UT Support CheckJitKernelInfo (#61431)
Browse files Browse the repository at this point in the history
* [PIR+CINN]All CINN Subgraph UT Support CheckJitKernelInfo

* fix conflict
  • Loading branch information
Aurelius84 authored Feb 1, 2024
1 parent ccb2835 commit 5f70fcd
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 71 deletions.
21 changes: 12 additions & 9 deletions test/ir/pir/cinn/adt/test_cinn_sub_graph_map_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from os.path import dirname

sys.path.append(dirname(dirname(__file__)))

import unittest

import numpy as np
import utils

import paddle


def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(
net, build_strategy=build_strategy, full_graph=True
)


def exp_sub(x):
y = paddle.exp(x)
z = y - x
Expand Down Expand Up @@ -58,12 +55,18 @@ def prepare_data(self):
self.x = paddle.randn(self.shape, dtype="float32")
self.x.stop_gradient = False

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})

def train(self, use_cinn):
paddle.seed(2022)
net = CINNSubGraphNet()
net = apply_to_static(net, use_cinn)
net = utils.apply_to_static(net, use_cinn)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_forward(self):
Expand Down
24 changes: 12 additions & 12 deletions test/ir/pir/cinn/symbolic/test_cinn_broadcast_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from os.path import dirname

sys.path.append(dirname(dirname(__file__)))

import unittest

import numpy as np
import utils

import paddle
from paddle.static import InputSpec


def apply_to_static(net, use_cinn, input_spec=None):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(
net,
input_spec=input_spec,
build_strategy=build_strategy,
full_graph=True,
)


def broadcast_add(x, y):
return paddle.exp(x) - y

Expand Down Expand Up @@ -62,16 +56,22 @@ def prepare_data(self):
self.y = paddle.randn(self.y_shape, dtype="float32")
self.y.stop_gradient = False

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})

def eval_symbolic(self, use_cinn):
paddle.seed(2022)
net = CINNSubGraphNet()
input_spec = [
InputSpec(shape=[None, 128], dtype='float32'),
InputSpec(shape=[None, 128], dtype='float32'),
]
net = apply_to_static(net, use_cinn, input_spec)
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x, self.y)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval_symolic(self):
Expand Down
26 changes: 14 additions & 12 deletions test/ir/pir/cinn/symbolic/test_cinn_reduce_symbolic_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from os.path import dirname

sys.path.append(dirname(dirname(__file__)))

import unittest

import utils

import paddle
from paddle.static import InputSpec


def apply_to_static(net, use_cinn, input_spec=None):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(
net,
input_spec=input_spec,
build_strategy=build_strategy,
full_graph=True,
)


def reduce_sum(x):
return paddle.sum(x, axis=-1)

Expand Down Expand Up @@ -57,15 +52,22 @@ def prepare_data(self):
self.x = paddle.randn(self.x_shape, dtype="float32")
self.x.stop_gradient = False

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})

def eval_symbolic(self, use_cinn):
paddle.seed(2022)
net = CINNSubGraphNet()
input_spec = [
InputSpec(shape=[None, 128], dtype='float32'),
]
net = apply_to_static(net, use_cinn, input_spec)
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)

return out

def test_eval_symolic(self):
Expand Down
40 changes: 24 additions & 16 deletions test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from os.path import dirname

sys.path.append(dirname(dirname(__file__)))

import unittest

import numpy as np
import utils

import paddle
from paddle.static import InputSpec
Expand All @@ -32,17 +37,6 @@ def get_sym_shape_str_for_op(net, input_spec, op_name):
return all_sym_shape_str


def apply_to_static(net, use_cinn, input_spec=None):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(
net,
input_spec=input_spec,
build_strategy=build_strategy,
full_graph=True,
)


def exp_sub(x):
y = paddle.exp(x)
z = y - x
Expand Down Expand Up @@ -107,6 +101,10 @@ def prepare_data(self):
self.x = paddle.randn(self.shape, dtype="float32")
self.x.stop_gradient = False

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})

def test_eval_symbolic(self):
pass

Expand All @@ -115,9 +113,11 @@ class TestCinnExpSubGraph(TestCinnSubGraphBase):
def eval_symbolic(self, use_cinn):
net = CINNSubGraphNet()
input_spec = [InputSpec(shape=[None, 128], dtype='float32')]
net = apply_to_static(net, use_cinn, input_spec)
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval_symbolic(self):
Expand All @@ -136,9 +136,11 @@ def eval_symbolic(self, use_cinn):
paddle.seed(2022)
net = CINNReshapeSubGraphNet()
input_spec = [InputSpec(shape=[None, 256], dtype='float32')]
net = apply_to_static(net, use_cinn, input_spec)
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval_symbolic(self):
Expand All @@ -164,9 +166,11 @@ def eval_symbolic(self, use_cinn):
InputSpec(shape=[None, None, None], dtype='float32'),
InputSpec(shape=[None, None], dtype='float32'),
]
net = apply_to_static(net, use_cinn, input_spec)
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x, self.y)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval_symbolic(self):
Expand Down Expand Up @@ -230,7 +234,7 @@ def eval_symbolic(self, use_cinn):
input_spec = [
InputSpec(shape=[None, None, 4096], dtype='float32'),
]
net = apply_to_static(net, use_cinn, input_spec)
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()

sym_shape_str_list = get_sym_shape_str_for_op(
Expand All @@ -244,6 +248,8 @@ def eval_symbolic(self, use_cinn):
)

out = net(self.hidden_states)
if use_cinn:
self.check_jit_kernel_info(net.forward)

return out

Expand Down Expand Up @@ -303,7 +309,7 @@ def eval_symbolic(self, use_cinn):
input_spec = [
InputSpec(shape=[None, None, 8, 96], dtype='float32'),
]
net = apply_to_static(net, use_cinn, input_spec)
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()

sym_shape_str_list = get_sym_shape_str_for_op(
Expand All @@ -317,6 +323,8 @@ def eval_symbolic(self, use_cinn):
)

out = net(self.hidden_states)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval_symbolic(self):
Expand Down
22 changes: 13 additions & 9 deletions test/ir/pir/cinn/symbolic/test_sub_graph_for_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from os.path import dirname

sys.path.append(dirname(dirname(__file__)))

import unittest

import numpy as np
import utils

import paddle


def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(
net, build_strategy=build_strategy, full_graph=True
)


def exp_sub(x):
y = paddle.exp(x)
z = y - x
Expand Down Expand Up @@ -58,12 +56,18 @@ def prepare_data(self):
self.x = paddle.randn(self.shape, dtype="float32")
self.x.stop_gradient = False

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})

def eval(self, use_cinn):
paddle.seed(2022)
net = CINNSubGraphNet()
net = apply_to_static(net, use_cinn)
net = utils.apply_to_static(net, use_cinn)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval(self):
Expand Down
26 changes: 13 additions & 13 deletions test/ir/pir/cinn/symbolic/test_sub_graph_for_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from os.path import dirname

sys.path.append(dirname(dirname(__file__)))

import unittest

import numpy as np
import utils

import paddle
from paddle.static import InputSpec


def apply_to_static(net, use_cinn, input_spec=None):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(
net,
input_spec=input_spec,
build_strategy=build_strategy,
full_graph=True,
)


def exp_sub(x):
y = paddle.exp(x)
z = y - x
Expand Down Expand Up @@ -62,16 +57,21 @@ def prepare_data(self):
self.x = paddle.randn(self.shape, dtype="float32")
self.x.stop_gradient = False

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 1)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1})

def eval(self, use_cinn):
paddle.seed(2022)
net = CINNSubGraphNet()
input_spec = [
InputSpec(shape=[None, 96], dtype='float32'),
]
if use_cinn:
net = apply_to_static(net, use_cinn, input_spec)
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval(self):
Expand Down

0 comments on commit 5f70fcd

Please sign in to comment.