From 14296691438aaa4d24b101d5c3950d873712eb6d Mon Sep 17 00:00:00 2001 From: Christoph Boeddeker Date: Thu, 13 Jun 2024 14:35:16 +0200 Subject: [PATCH] fix tbx_utils for matplotlib>=3.9 (See https://matplotlib.org/stable/api/prev_api_changes/api_changes_3.9.0.html) --- padertorch/summary/tbx_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/padertorch/summary/tbx_utils.py b/padertorch/summary/tbx_utils.py index 09c26f33..8076d756 100644 --- a/padertorch/summary/tbx_utils.py +++ b/padertorch/summary/tbx_utils.py @@ -79,7 +79,7 @@ def mask_to_image( is assumed to be in the second position, i.e., `(frames, batch [optional], features)`. color: A color map name. The name is forwarded to - `matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`, + `matplotlib.pyplot.get_cmap` to get the color map. If `None`, grayscale is used. origin: Origin of the plot. Can be `'upper'` or `'lower'`. @@ -119,7 +119,7 @@ def stft_to_image( signal: Shape (frames, batch [optional], features) batch_first: if true mask shape (batch [optional], frames, features] color: A color map name. The name is forwarded to - `matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`, + `matplotlib.pyplot.get_cmap` to get the color map. If `None`, grayscale is used. origin: Origin of the plot. Can be `'upper'` or `'lower'`. visible_dB: How many dezibel are visible in the image. @@ -203,7 +203,7 @@ def __call__(self, image, color): except KeyError: try: import matplotlib.pyplot as plt - cmap = plt.cm.get_cmap(color) + cmap = plt.get_cmap(color) self.color_to_cmap[color] = cmap except ImportError: from warnings import warn @@ -243,7 +243,7 @@ def spectrogram_to_image( is assumed to be in the second position, i.e., `(frames, batch [optional], features)`. color: A color map name. The name is forwarded to - `matplotlib.pyplot.cm.get_cmap` to get the color map. + `matplotlib.pyplot.get_cmap` to get the color map. origin: Origin of the plot. Can be `'upper'` or `'lower'`. log: If `True`, the spectrogram is plotted in log domain and shows a 50dB range. The 50dB can be changed with the argument `visible_dB`. @@ -299,6 +299,9 @@ def audio( docs for further information on the return type. """ signal = to_numpy(signal, detach=True) + if signal.dtype.kind == 'c': + raise ValueError( + f'Complex datatype ({signal.dtype}) is not supported for audio.') signal = _remove_batch_axis(signal, batch_first=batch_first, ndim=1) @@ -454,3 +457,5 @@ def review_dict( assert operator.xor(loss is None, losses is None), (loss, losses) return review + +