[PIR]Migrate BatchNorm2D into pir (#58113)
0x45f authored Oct 19, 2023
1 parent 0d531a7 commit cfd1d0e
Showing 4 changed files with 37 additions and 14 deletions.
python/paddle/nn/functional/norm.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -194,7 +194,7 @@ def batch_norm(
     else:
         trainable_statistics = not use_global_stats
 
-    if in_dygraph_mode():
+    if in_dynamic_or_pir_mode():
         batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm(
             x,
             running_mean,
```
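With this change, `paddle.nn.functional.batch_norm` dispatches to `_C_ops.batch_norm` under the new PIR mode as well as dygraph. A minimal usage sketch of the functional API (shapes and values are illustrative, not taken from the commit):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

# Inference-style batch norm over an NCHW tensor; after this commit the
# same _C_ops.batch_norm fast path is taken in dygraph and in PIR mode.
x = paddle.to_tensor(np.random.random([4, 10, 16, 16]).astype('float32'))
running_mean = paddle.zeros([10])
running_var = paddle.ones([10])
weight = paddle.ones([10])
bias = paddle.zeros([10])

y = F.batch_norm(x, running_mean, running_var, weight, bias, training=False)
print(y.shape)  # [4, 10, 16, 16]
```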
python/paddle/nn/layer/norm.py (20 changes: 15 additions & 5 deletions)

```diff
@@ -37,7 +37,13 @@
 
 from ...base import dygraph_utils
 from ...base.data_feeder import check_variable_and_dtype
-from ...framework import ParamAttr, _global_flags, get_default_dtype, no_grad
+from ...framework import (
+    ParamAttr,
+    _global_flags,
+    get_default_dtype,
+    in_dynamic_or_pir_mode,
+    no_grad,
+)
 from .. import functional as F
 from ..functional import batch_norm, instance_norm, layer_norm
 from ..initializer import Constant, Normal
@@ -1076,7 +1082,7 @@ def __init__(
         self._trainable_statistics = trainable_statistics
 
     def forward(self, input):
-        if in_dynamic_mode():
+        if in_dynamic_or_pir_mode():
             batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm(
                 input,
                 self._mean,
@@ -1092,9 +1098,13 @@ def forward(self, input):
             )
             if self._act is None:
                 return batch_norm_out
-            return dygraph_utils._append_activation_in_dygraph(
-                batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
-            )
+            if in_dynamic_mode():
+                return dygraph_utils._append_activation_in_dygraph(
+                    batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
+                )
+            else:
+                act_op = getattr(_C_ops, self._act)
+                return act_op(input)
         else:
             # create output
             # mean and mean_out share the same memory
```
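In the PIR branch of `forward`, `dygraph_utils._append_activation_in_dygraph` is unavailable, so the activation op is resolved by name from `_C_ops` and called directly. A small sketch of that lookup pattern, run in dygraph ('relu' is an illustrative activation name):

```python
import paddle
from paddle import _C_ops

# Name-based kernel lookup, mirroring getattr(_C_ops, self._act) above.
x = paddle.to_tensor([-1.0, 0.5, 2.0])
act_op = getattr(_C_ops, 'relu')
print(act_op(x))  # [0. , 0.5, 2. ]
```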
test/legacy_test/test_batch_norm_op.py (13 changes: 10 additions & 3 deletions)

```diff
@@ -28,6 +28,7 @@
 from paddle import base
 from paddle.base import Program, core, program_guard
 from paddle.base.framework import grad_var_name
+from paddle.pir_utils import test_with_pir_api
 
 _set_use_system_allocator(True)
 
@@ -857,6 +858,7 @@ def compute(x, is_test, trainable_statistics):
         y2 = compute(x, True, True)
         np.testing.assert_allclose(y1, y2, rtol=1e-05)
 
+    @test_with_pir_api
     def test_static(self):
         places = [base.CPUPlace()]
         if core.is_compiled_with_cuda():
@@ -866,7 +868,9 @@ def test_static(self):
         shape = [4, 10, 16, 16]
 
         def compute(x_np, is_test, trainable_statistics):
-            with program_guard(Program(), Program()):
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with paddle.static.program_guard(main_program, startup_program):
                 bn = paddle.nn.BatchNorm(
                     shape[1],
                     is_test=is_test,
@@ -876,7 +880,7 @@ def compute(x_np, is_test, trainable_statistics):
                     name='x', shape=x_np.shape, dtype=x_np.dtype
                 )
                 y = bn(x)
-                exe.run(base.default_startup_program())
+                exe.run(startup_program)
                 r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
                 return r
 
@@ -887,8 +891,11 @@ def compute(x_np, is_test, trainable_statistics):
 
 
 class TestDygraphBatchNormOpenReserveSpace(unittest.TestCase):
+    @test_with_pir_api
     def test_reservespace(self):
-        with program_guard(Program(), Program()):
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        with paddle.static.program_guard(main_program, startup_program):
             paddle.enable_static()
             x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
             x = paddle.static.data(name='x', shape=x.shape, dtype=x.dtype)
```
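The recurring test refactor above swaps the global default programs for explicit `Program` objects, so a test decorated with `@test_with_pir_api` can be re-run against both the legacy IR and PIR. A hedged sketch of the pattern (shapes and the CPU place are illustrative):

```python
import numpy as np
import paddle

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name='x', shape=[4, 10, 16, 16], dtype='float32')
    bn = paddle.nn.BatchNorm(10, is_test=True)
    y = bn(x)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_program)  # run the explicit startup program, not the global default
x_np = np.random.random([4, 10, 16, 16]).astype('float32')
(out,) = exe.run(main_program, feed={'x': x_np}, fetch_list=[y])
print(out.shape)  # (4, 10, 16, 16)
```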
test/legacy_test/test_batch_norm_op_v2.py (16 changes: 11 additions & 5 deletions)

```diff
@@ -18,7 +18,8 @@
 
 import paddle
 from paddle import base
-from paddle.base import Program, core, program_guard
+from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
 
 
 class TestBatchNorm(unittest.TestCase):
@@ -210,6 +211,7 @@ def compute_v4(x):
         np.testing.assert_allclose(y1, y2, rtol=1e-05)
         np.testing.assert_allclose(y3, y4, rtol=1e-05)
 
+    @test_with_pir_api
     def test_static(self):
         places = [base.CPUPlace()]
         if core.is_compiled_with_cuda():
@@ -219,7 +221,9 @@ def test_static(self):
         shape = [4, 10, 16, 16]
 
         def compute_v1(x_np, is_test, trainable_statistics):
-            with program_guard(Program(), Program()):
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with base.program_guard(main_program, startup_program):
                 bn = paddle.nn.BatchNorm(
                     shape[1],
                     is_test=is_test,
@@ -229,18 +233,20 @@ def compute_v1(x_np, is_test, trainable_statistics):
                     name='x', shape=x_np.shape, dtype=x_np.dtype
                 )
                 y = bn(x)
-                exe.run(base.default_startup_program())
+                exe.run(startup_program)
                 r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
                 return r
 
         def compute_v2(x_np):
-            with program_guard(Program(), Program()):
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with base.program_guard(main_program, startup_program):
                 bn = paddle.nn.BatchNorm2D(shape[1])
                 x = paddle.static.data(
                     name='x', shape=x_np.shape, dtype=x_np.dtype
                 )
                 y = bn(x)
-                exe.run(base.default_startup_program())
+                exe.run(startup_program)
                 r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
                 return r
```
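These v2 tests check that `paddle.nn.BatchNorm2D` gives matching results across execution modes. A hedged end-to-end sketch of that equivalence check (illustrative shapes; both layers use the default weight/bias initialization):

```python
import numpy as np
import paddle

x_np = np.random.random([4, 10, 16, 16]).astype('float32')

# Dygraph result.
paddle.disable_static()
y_dygraph = paddle.nn.BatchNorm2D(10)(paddle.to_tensor(x_np)).numpy()

# Static-graph result; runs on PIR when the new IR is enabled.
paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name='x', shape=x_np.shape, dtype='float32')
    y = paddle.nn.BatchNorm2D(10)(x)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_program)
(y_static,) = exe.run(main_program, feed={'x': x_np}, fetch_list=[y])
np.testing.assert_allclose(y_dygraph, y_static, rtol=1e-05)
```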
