Skip to content

Commit

Permalink
Merge pull request #157 from boeddeker/master
Browse files Browse the repository at this point in the history
fix tests and contrib: lazy import of torch_complex
  • Loading branch information
boeddeker authored Jun 14, 2024
2 parents 2e454d2 + 287f544 commit f6acb39
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
12 changes: 6 additions & 6 deletions padertorch/contrib/cb/complex.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import numpy as np
import torch
import torch_complex
from torch_complex import ComplexTensor

__all__ = {
'ComplexTensor',
}


Expand All @@ -19,8 +16,11 @@ def is_torch(obj):
>>> is_torch(ComplexTensor(np.zeros(3)))
True
"""
if torch.is_tensor(obj) or isinstance(obj, ComplexTensor):
if torch.is_tensor(obj):
return True
else:
return False
if type(obj).__name__ == 'ComplexTensor':
from torch_complex import ComplexTensor
if isinstance(obj, ComplexTensor):
return True
return False

13 changes: 9 additions & 4 deletions padertorch/summary/tbx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def mask_to_image(
is assumed to be in the second position, i.e.,
`(frames, batch [optional], features)`.
color: A color map name. The name is forwarded to
`matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`,
`matplotlib.pyplot.get_cmap` to get the color map. If `None`,
grayscale is used.
origin: Origin of the plot. Can be `'upper'` or `'lower'`.
Expand Down Expand Up @@ -119,7 +119,7 @@ def stft_to_image(
signal: Shape (frames, batch [optional], features)
batch_first: if true mask shape (batch [optional], frames, features]
color: A color map name. The name is forwarded to
`matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`,
`matplotlib.pyplot.get_cmap` to get the color map. If `None`,
grayscale is used.
origin: Origin of the plot. Can be `'upper'` or `'lower'`.
visible_dB: How many dezibel are visible in the image.
Expand Down Expand Up @@ -203,7 +203,7 @@ def __call__(self, image, color):
except KeyError:
try:
import matplotlib.pyplot as plt
cmap = plt.cm.get_cmap(color)
cmap = plt.get_cmap(color)
self.color_to_cmap[color] = cmap
except ImportError:
from warnings import warn
Expand Down Expand Up @@ -243,7 +243,7 @@ def spectrogram_to_image(
is assumed to be in the second position, i.e.,
`(frames, batch [optional], features)`.
color: A color map name. The name is forwarded to
`matplotlib.pyplot.cm.get_cmap` to get the color map.
`matplotlib.pyplot.get_cmap` to get the color map.
origin: Origin of the plot. Can be `'upper'` or `'lower'`.
log: If `True`, the spectrogram is plotted in log domain and shows a
50dB range. The 50dB can be changed with the argument `visible_dB`.
Expand Down Expand Up @@ -299,6 +299,9 @@ def audio(
docs for further information on the return type.
"""
signal = to_numpy(signal, detach=True)
if signal.dtype.kind == 'c':
raise ValueError(
f'Complex datatype ({signal.dtype}) is not supported for audio.')

signal = _remove_batch_axis(signal, batch_first=batch_first, ndim=1)

Expand Down Expand Up @@ -454,3 +457,5 @@ def review_dict(
assert operator.xor(loss is None, losses is None), (loss, losses)

return review


59 changes: 44 additions & 15 deletions tests/test_train/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,11 @@ def test_released_tensors():
dt_dataset = dt_dataset[:2]

class ReleaseTestHook(pt.train.hooks.Hook):
def get_all_tensors(self):
def __init__(self, global_tensors):
self.global_tensors = global_tensors

@staticmethod
def get_all_tensors():
import gc
tensors = []
for obj in gc.get_objects():
Expand Down Expand Up @@ -607,16 +611,22 @@ def show_referrers_type(cls, obj, depth, ignore=list()):
ignore=ignore + [referrers, o, obj]
):
l.append(textwrap.indent(s, ' '*4))
else:
l.append('... cycle ...')
class c:
magenta = '\033[35m'
reset = '\033[0m'
cyan = '\033[36m'

if inspect.isframe(obj):
frame_info = inspect.getframeinfo(obj, context=1)
if frame_info.function == 'show_referrers_type':
pass
else:
info = f' {frame_info.function}, {frame_info.filename}:{frame_info.lineno}'
info = f' {frame_info.function}, {c.magenta}{frame_info.filename}{c.reset}:{c.magenta}{frame_info.lineno}{c.reset}'
l.append(f'Frame: {type(obj)} {info}')
else:
l.append(str(type(obj)) + str(obj)[:80].replace('\n', ' '))
l.append(str(type(obj)) + str(obj)[:160].replace('\n', ' '))
return l

def pre_step(self, trainer: 'pt.Trainer'):
Expand Down Expand Up @@ -645,17 +655,18 @@ def pre_step(self, trainer: 'pt.Trainer'):
import textwrap
print(len(all_tensors), len(parameters), len(optimizer_tensors))

assert len(all_tensors) == len(parameters) + len(optimizer_tensors) + len(grads), (
def format_(name, tensors):
s = textwrap.indent("\n".join(map(str, all_tensors)), " "*8)
return f'{name}: {len(tensors)}\n{s}\n'

assert len(all_tensors) == len(parameters) + len(optimizer_tensors) + len(grads) + len(self.global_tensors), (
f'pre_step\n'
f'{summary}\n'
f'all_tensors: {len(all_tensors)}\n'
+ textwrap.indent("\n".join(map(str, all_tensors)), " "*8) + f'\n'
f'parameters: {len(parameters)}\n'
+ textwrap.indent("\n".join(map(str, parameters)), " "*8) + f'\n'
f'parameters: {len(grads)}\n'
+ textwrap.indent("\n".join(map(str, grads)), " "*8) + f'\n'
f'optimizer_tensors: {len(optimizer_tensors)}\n'
+ textwrap.indent("\n".join(map(str, optimizer_tensors)), " "*8) + f'\n'
+ format_('all_tensors', all_tensors)
+ format_('parameters', parameters)
+ format_('optimizer_tensors', optimizer_tensors)
+ format_('grads', grads)
+ format_('global_tensors', self.global_tensors)
)

def post_step(
Expand All @@ -665,12 +676,30 @@ def post_step(
parameters = list(trainer.model.parameters())
assert len(all_tensors) > len(parameters), ('post_step', all_tensors, parameters)


print('pre TemporaryDirectory', ReleaseTestHook.get_all_tensors())

try:
# Between Torch 2.1.2 and 2.3.1 someone created _nt_view_dummy,
# which is the only Tensor in torch, that is created with an import
# of torch code.
# For some unknown reason the Adam optimizer triggers this import
# with the __init__ call.
# Do it here manually to be able to find all "global" tensors.
from torch.nested._internal.nested_tensor import _nt_view_dummy
except Exception:
pass

global_tensors = ReleaseTestHook.get_all_tensors()

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)

model = Model()
optimizer = pt.optimizer.Adam()
t = pt.Trainer(
Model(),
optimizer=pt.optimizer.Adam(),
model=model,
optimizer=optimizer,
storage_dir=str(tmp_dir),
stop_trigger=(1, 'epoch'),
summary_trigger=(1, 'epoch'),
Expand All @@ -679,7 +708,7 @@ def post_step(
t.register_validation_hook(
validation_iterator=dt_dataset, max_checkpoints=None
)
t.register_hook(ReleaseTestHook()) # This hook will do the tests
t.register_hook(ReleaseTestHook(global_tensors)) # This hook will do the tests
t.train(tr_dataset)


Expand Down

0 comments on commit f6acb39

Please sign in to comment.