Skip to content

Commit

Permalink
Merge pull request matplotlib#27562 from QuLogic/no-alpha-copy
Browse files Browse the repository at this point in the history
Avoid an extra copy/resample if imshow input has no alpha
  • Loading branch information
jklymak authored Jan 10, 2024
2 parents 7f8b9b3 + 3d6a349 commit 8542398
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
12 changes: 8 additions & 4 deletions lib/matplotlib/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,11 +555,15 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0,
if A.ndim == 2: # _interpolation_stage == 'rgba'
self.norm.autoscale_None(A)
A = self.to_rgba(A)
if A.shape[2] == 3:
A = _rgb_to_rgba(A)
alpha = self._get_scalar_alpha()
output_alpha = _resample( # resample alpha channel
self, A[..., 3], out_shape, t, alpha=alpha)
if A.shape[2] == 3:
# No need to resample alpha or make a full array; NumPy will expand
# this out and cast to uint8 if necessary when it's assigned to the
# alpha channel below.
output_alpha = (255 * alpha) if A.dtype == np.uint8 else alpha
else:
output_alpha = _resample( # resample alpha channel
self, A[..., 3], out_shape, t, alpha=alpha)
output = _resample( # resample rgb channels
self, _rgb_to_rgba(A[..., :3]), out_shape, t, alpha=alpha)
output[..., 3] = output_alpha # recombine rgb and alpha
Expand Down
26 changes: 26 additions & 0 deletions lib/matplotlib/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,32 @@ def test_image_alpha():
ax3.imshow(Z, alpha=0.5, interpolation='nearest')


@mpl.style.context('mpl20')
@check_figures_equal(extensions=['png'])
def test_imshow_alpha(fig_test, fig_ref):
np.random.seed(19680801)

rgbf = np.random.rand(6, 6, 3)
rgbu = np.uint8(rgbf * 255)
((ax0, ax1), (ax2, ax3)) = fig_test.subplots(2, 2)
ax0.imshow(rgbf, alpha=0.5)
ax1.imshow(rgbf, alpha=0.75)
ax2.imshow(rgbu, alpha=0.5)
ax3.imshow(rgbu, alpha=0.75)

rgbaf = np.concatenate((rgbf, np.ones((6, 6, 1))), axis=2)
rgbau = np.concatenate((rgbu, np.full((6, 6, 1), 255, np.uint8)), axis=2)
((ax0, ax1), (ax2, ax3)) = fig_ref.subplots(2, 2)
rgbaf[:, :, 3] = 0.5
ax0.imshow(rgbaf)
rgbaf[:, :, 3] = 0.75
ax1.imshow(rgbaf)
rgbau[:, :, 3] = 127
ax2.imshow(rgbau)
rgbau[:, :, 3] = 191
ax3.imshow(rgbau)


def test_cursor_data():
from matplotlib.backend_bases import MouseEvent

Expand Down

0 comments on commit 8542398

Please sign in to comment.