Skip to content

Commit

Permalink
fix the div 0 errors in psroi_pool (PaddlePaddle#49965)
Browse files Browse the repository at this point in the history
* fix the div 0 errors in psroi_pool

* fix case 7

* rool back sth.
  • Loading branch information
Liyulingyue authored and pangengzheng committed Feb 2, 2023
1 parent 1211722 commit 343d1c1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/paddle/fluid/tests/unittests/test_psroi_pool_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,22 @@ def test_channel_error():
self.assertRaises(ValueError, test_channel_error)


class TestPSROIPoolZeroDivError(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.x = paddle.uniform([2, 490, 28, 28], dtype='float32')
self.boxes = paddle.to_tensor(
[[1, 5, 8, 10], [4, 2, 6, 7], [12, 12, 19, 21]], dtype='float32'
)
self.boxes_num = paddle.to_tensor([1, 2], dtype='int32')

def test_errors(self):
def test_zero_div_error():
paddle.vision.ops.psroi_pool(self.x, self.boxes, self.boxes_num, 0)

self.assertRaises(ValueError, test_zero_div_error)


class TestPSROIPoolStaticAPI(unittest.TestCase):
def setUp(self):
paddle.enable_static()
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/vision/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,6 +1424,8 @@ def psroi_pool(x, boxes, boxes_num, output_size, spatial_scale=1.0, name=None):
output_size = (output_size, output_size)
pooled_height, pooled_width = output_size
assert len(x.shape) == 4, "Input features with shape should be (N, C, H, W)"
if pooled_height * pooled_width == 0:
raise ValueError('output_size should not contain 0.')
output_channels = int(x.shape[1] / (pooled_height * pooled_width))
if in_dygraph_mode():
return _C_ops.psroi_pool(
Expand Down

0 comments on commit 343d1c1

Please sign in to comment.