Skip to content

Commit

Permalink
fix tbx_utils for matplotlib>=3.9 (See https://matplotlib.org/stable/…
Browse files Browse the repository at this point in the history
  • Loading branch information
boeddeker committed Jun 13, 2024
1 parent 94c1f18 commit 1429669
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions padertorch/summary/tbx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def mask_to_image(
is assumed to be in the second position, i.e.,
`(frames, batch [optional], features)`.
color: A color map name. The name is forwarded to
`matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`,
`matplotlib.pyplot.get_cmap` to get the color map. If `None`,
grayscale is used.
origin: Origin of the plot. Can be `'upper'` or `'lower'`.
Expand Down Expand Up @@ -119,7 +119,7 @@ def stft_to_image(
signal: Shape (frames, batch [optional], features)
batch_first: if true mask shape (batch [optional], frames, features]
color: A color map name. The name is forwarded to
`matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`,
`matplotlib.pyplot.get_cmap` to get the color map. If `None`,
grayscale is used.
origin: Origin of the plot. Can be `'upper'` or `'lower'`.
visible_dB: How many dezibel are visible in the image.
Expand Down Expand Up @@ -203,7 +203,7 @@ def __call__(self, image, color):
except KeyError:
try:
import matplotlib.pyplot as plt
cmap = plt.cm.get_cmap(color)
cmap = plt.get_cmap(color)
self.color_to_cmap[color] = cmap
except ImportError:
from warnings import warn
Expand Down Expand Up @@ -243,7 +243,7 @@ def spectrogram_to_image(
is assumed to be in the second position, i.e.,
`(frames, batch [optional], features)`.
color: A color map name. The name is forwarded to
`matplotlib.pyplot.cm.get_cmap` to get the color map.
`matplotlib.pyplot.get_cmap` to get the color map.
origin: Origin of the plot. Can be `'upper'` or `'lower'`.
log: If `True`, the spectrogram is plotted in log domain and shows a
50dB range. The 50dB can be changed with the argument `visible_dB`.
Expand Down Expand Up @@ -299,6 +299,9 @@ def audio(
docs for further information on the return type.
"""
signal = to_numpy(signal, detach=True)
if signal.dtype.kind == 'c':
raise ValueError(
f'Complex datatype ({signal.dtype}) is not supported for audio.')

signal = _remove_batch_axis(signal, batch_first=batch_first, ndim=1)

Expand Down Expand Up @@ -454,3 +457,5 @@ def review_dict(
assert operator.xor(loss is None, losses is None), (loss, losses)

return review


0 comments on commit 1429669

Please sign in to comment.