diff --git a/padertorch/summary/tbx_utils.py b/padertorch/summary/tbx_utils.py index 09c26f33..8076d756 100644 --- a/padertorch/summary/tbx_utils.py +++ b/padertorch/summary/tbx_utils.py @@ -79,7 +79,7 @@ def mask_to_image( is assumed to be in the second position, i.e., `(frames, batch [optional], features)`. color: A color map name. The name is forwarded to - `matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`, + `matplotlib.pyplot.get_cmap` to get the color map. If `None`, grayscale is used. origin: Origin of the plot. Can be `'upper'` or `'lower'`. @@ -119,7 +119,7 @@ def stft_to_image( signal: Shape (frames, batch [optional], features) batch_first: if true mask shape (batch [optional], frames, features] color: A color map name. The name is forwarded to - `matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`, + `matplotlib.pyplot.get_cmap` to get the color map. If `None`, grayscale is used. origin: Origin of the plot. Can be `'upper'` or `'lower'`. visible_dB: How many dezibel are visible in the image. @@ -203,7 +203,7 @@ def __call__(self, image, color): except KeyError: try: import matplotlib.pyplot as plt - cmap = plt.cm.get_cmap(color) + cmap = plt.get_cmap(color) self.color_to_cmap[color] = cmap except ImportError: from warnings import warn @@ -243,7 +243,7 @@ def spectrogram_to_image( is assumed to be in the second position, i.e., `(frames, batch [optional], features)`. color: A color map name. The name is forwarded to - `matplotlib.pyplot.cm.get_cmap` to get the color map. + `matplotlib.pyplot.get_cmap` to get the color map. origin: Origin of the plot. Can be `'upper'` or `'lower'`. log: If `True`, the spectrogram is plotted in log domain and shows a 50dB range. The 50dB can be changed with the argument `visible_dB`. @@ -299,6 +299,9 @@ def audio( docs for further information on the return type. """ signal = to_numpy(signal, detach=True) + if signal.dtype.kind == 'c': + raise ValueError( + f'Complex datatype ({signal.dtype}) is not supported for audio.') signal = _remove_batch_axis(signal, batch_first=batch_first, ndim=1) @@ -454,3 +457,5 @@ def review_dict( assert operator.xor(loss is None, losses is None), (loss, losses) return review + +