Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix tests and contrib: lazy import of torch_complex #157

Merged
merged 3 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions padertorch/contrib/cb/complex.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import numpy as np
import torch
import torch_complex
from torch_complex import ComplexTensor

__all__ = {
'ComplexTensor',
}


Expand All @@ -19,8 +16,11 @@ def is_torch(obj):
>>> is_torch(ComplexTensor(np.zeros(3)))
True
"""
if torch.is_tensor(obj) or isinstance(obj, ComplexTensor):
if torch.is_tensor(obj):
return True
else:
return False
if type(obj).__name__ == 'ComplexTensor':
from torch_complex import ComplexTensor
if isinstance(obj, ComplexTensor):
return True
return False

13 changes: 9 additions & 4 deletions padertorch/summary/tbx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def mask_to_image(
is assumed to be in the second position, i.e.,
`(frames, batch [optional], features)`.
color: A color map name. The name is forwarded to
`matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`,
`matplotlib.pyplot.get_cmap` to get the color map. If `None`,
grayscale is used.
origin: Origin of the plot. Can be `'upper'` or `'lower'`.

Expand Down Expand Up @@ -119,7 +119,7 @@ def stft_to_image(
signal: Shape (frames, batch [optional], features)
batch_first: if true mask shape (batch [optional], frames, features]
color: A color map name. The name is forwarded to
`matplotlib.pyplot.cm.get_cmap` to get the color map. If `None`,
`matplotlib.pyplot.get_cmap` to get the color map. If `None`,
grayscale is used.
origin: Origin of the plot. Can be `'upper'` or `'lower'`.
visible_dB: How many dezibel are visible in the image.
Expand Down Expand Up @@ -203,7 +203,7 @@ def __call__(self, image, color):
except KeyError:
try:
import matplotlib.pyplot as plt
cmap = plt.cm.get_cmap(color)
cmap = plt.get_cmap(color)
self.color_to_cmap[color] = cmap
except ImportError:
from warnings import warn
Expand Down Expand Up @@ -243,7 +243,7 @@ def spectrogram_to_image(
is assumed to be in the second position, i.e.,
`(frames, batch [optional], features)`.
color: A color map name. The name is forwarded to
`matplotlib.pyplot.cm.get_cmap` to get the color map.
`matplotlib.pyplot.get_cmap` to get the color map.
origin: Origin of the plot. Can be `'upper'` or `'lower'`.
log: If `True`, the spectrogram is plotted in log domain and shows a
50dB range. The 50dB can be changed with the argument `visible_dB`.
Expand Down Expand Up @@ -299,6 +299,9 @@ def audio(
docs for further information on the return type.
"""
signal = to_numpy(signal, detach=True)
if signal.dtype.kind == 'c':
raise ValueError(
f'Complex datatype ({signal.dtype}) is not supported for audio.')

signal = _remove_batch_axis(signal, batch_first=batch_first, ndim=1)

Expand Down Expand Up @@ -454,3 +457,5 @@ def review_dict(
assert operator.xor(loss is None, losses is None), (loss, losses)

return review


59 changes: 44 additions & 15 deletions tests/test_train/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,11 @@ def test_released_tensors():
dt_dataset = dt_dataset[:2]

class ReleaseTestHook(pt.train.hooks.Hook):
def get_all_tensors(self):
def __init__(self, global_tensors):
self.global_tensors = global_tensors

@staticmethod
def get_all_tensors():
import gc
tensors = []
for obj in gc.get_objects():
Expand Down Expand Up @@ -607,16 +611,22 @@ def show_referrers_type(cls, obj, depth, ignore=list()):
ignore=ignore + [referrers, o, obj]
):
l.append(textwrap.indent(s, ' '*4))
else:
l.append('... cycle ...')
class c:
magenta = '\033[35m'
reset = '\033[0m'
cyan = '\033[36m'

if inspect.isframe(obj):
frame_info = inspect.getframeinfo(obj, context=1)
if frame_info.function == 'show_referrers_type':
pass
else:
info = f' {frame_info.function}, {frame_info.filename}:{frame_info.lineno}'
info = f' {frame_info.function}, {c.magenta}{frame_info.filename}{c.reset}:{c.magenta}{frame_info.lineno}{c.reset}'
l.append(f'Frame: {type(obj)} {info}')
else:
l.append(str(type(obj)) + str(obj)[:80].replace('\n', ' '))
l.append(str(type(obj)) + str(obj)[:160].replace('\n', ' '))
return l

def pre_step(self, trainer: 'pt.Trainer'):
Expand Down Expand Up @@ -645,17 +655,18 @@ def pre_step(self, trainer: 'pt.Trainer'):
import textwrap
print(len(all_tensors), len(parameters), len(optimizer_tensors))

assert len(all_tensors) == len(parameters) + len(optimizer_tensors) + len(grads), (
def format_(name, tensors):
s = textwrap.indent("\n".join(map(str, all_tensors)), " "*8)
return f'{name}: {len(tensors)}\n{s}\n'

assert len(all_tensors) == len(parameters) + len(optimizer_tensors) + len(grads) + len(self.global_tensors), (
f'pre_step\n'
f'{summary}\n'
f'all_tensors: {len(all_tensors)}\n'
+ textwrap.indent("\n".join(map(str, all_tensors)), " "*8) + f'\n'
f'parameters: {len(parameters)}\n'
+ textwrap.indent("\n".join(map(str, parameters)), " "*8) + f'\n'
f'parameters: {len(grads)}\n'
+ textwrap.indent("\n".join(map(str, grads)), " "*8) + f'\n'
f'optimizer_tensors: {len(optimizer_tensors)}\n'
+ textwrap.indent("\n".join(map(str, optimizer_tensors)), " "*8) + f'\n'
+ format_('all_tensors', all_tensors)
+ format_('parameters', parameters)
+ format_('optimizer_tensors', optimizer_tensors)
+ format_('grads', grads)
+ format_('global_tensors', self.global_tensors)
)

def post_step(
Expand All @@ -665,12 +676,30 @@ def post_step(
parameters = list(trainer.model.parameters())
assert len(all_tensors) > len(parameters), ('post_step', all_tensors, parameters)


print('pre TemporaryDirectory', ReleaseTestHook.get_all_tensors())

try:
# Between Torch 2.1.2 and 2.3.1 someone created _nt_view_dummy,
# which is the only Tensor in torch, that is created with an import
# of torch code.
# For some unknown reason the Adam optimizer triggers this import
# with the __init__ call.
# Do it here manually to be able to find all "global" tensors.
from torch.nested._internal.nested_tensor import _nt_view_dummy
except Exception:
pass

global_tensors = ReleaseTestHook.get_all_tensors()

with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)

model = Model()
optimizer = pt.optimizer.Adam()
t = pt.Trainer(
Model(),
optimizer=pt.optimizer.Adam(),
model=model,
optimizer=optimizer,
storage_dir=str(tmp_dir),
stop_trigger=(1, 'epoch'),
summary_trigger=(1, 'epoch'),
Expand All @@ -679,7 +708,7 @@ def post_step(
t.register_validation_hook(
validation_iterator=dt_dataset, max_checkpoints=None
)
t.register_hook(ReleaseTestHook()) # This hook will do the tests
t.register_hook(ReleaseTestHook(global_tensors)) # This hook will do the tests
t.train(tr_dataset)


Expand Down
Loading