Skip to content

Commit

Permalink
Merge pull request #159 from boeddeker/master
Browse files Browse the repository at this point in the history
fix doctests for numpy 2
  • Loading branch information
boeddeker authored Jun 18, 2024
2 parents e673eb4 + eaefe49 commit da183b4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
7 changes: 5 additions & 2 deletions padertorch/data/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def _get_rand_int(rng, *args, **kwargs):

def get_anchor(
num_samples: int, length: int, shift: int = None,
mode: str = 'left', rng=np.random
mode: str = 'left', rng: 'np.random.RandomState' = np.random
) -> int:
"""
Calculates anchor for the boundaries for segmentation of a signal
Expand All @@ -372,6 +372,9 @@ def get_anchor(
random_max_segments: Randomly chooses the anchor such that
the maximum number of segments are created
rng: random number generator (`numpy.random`)
e.g. `np.random.RandomState(0)` or `np.random.default_rng(0)`.
Not, that `np.random.default_rng` differs from `np.random` and has
sometimes other return types (e.g. python `int` vs `np.int64`)
Returns:
integer value describing the anchor
Expand All @@ -387,7 +390,7 @@ def get_anchor(
1
>>> get_anchor(24, 10, 3, mode='random')
10
>>> get_anchor(24, 10, 3, mode='random', rng=np.random.default_rng(seed=4))
>>> get_anchor(24, 10, 3, mode='random', rng=np.random.RandomState(seed=4))
10
>>> get_anchor(24, 10, 3, mode='random_max_segments')
3
Expand Down
4 changes: 2 additions & 2 deletions padertorch/ops/losses/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,15 +220,15 @@ def si_sdr_loss(estimate, target, reduction='mean', offset_invariant=False,
... print('Numpy metric:', si_sdr(estimate.numpy(), target.numpy()))
Perfect estimation
>>> si_sdr(reference.numpy(), reference.numpy())
>>> print(si_sdr(reference.numpy(), reference.numpy()))
inf
>>> sdr_loss(reference, reference)
tensor(-inf, dtype=torch.float64)
>>> si_sdr_loss(reference, reference) < -300 # Torch CPU is not hardware independent
tensor(True)
>>> si_sdr_loss(reference.to(torch.float32), reference.to(torch.float32)) < -130 # Torch CPU is not hardware independent
tensor(True)
>>> si_sdr(reference.numpy(), (reference * 2).numpy())
>>> print(si_sdr(reference.numpy(), (reference * 2).numpy()))
inf
>>> si_sdr_loss(reference, reference * 2) < -300 # Torch CPU is not hardware independent
tensor(True)
Expand Down
6 changes: 4 additions & 2 deletions padertorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ def to_list(x, length=None):
Complicated corner cases are e.g. `range()` and `dict.values()`, which are
handled here.
>>> from paderbox.utils.pretty import pprint
>>> to_list(1)
[1]
>>> to_list([1])
[1]
>>> to_list((i for i in range(3)))
[0, 1, 2]
>>> to_list(np.arange(3))
[0, 1, 2]
>>> pprint(to_list(np.arange(3)), nep51=True) # use pprint to support numpy 1 and 2
[np.int64(0), np.int64(1), np.int64(2)]
>>> to_list({'a': 1})
[{'a': 1}]
>>> to_list({'a': 1}.keys())
Expand Down

0 comments on commit da183b4

Please sign in to comment.