Skip to content

Commit

Permalink
Add interp dimension in nearest-neighbors interpolation to match line…
Browse files Browse the repository at this point in the history
…ar interpolation (#177)

* Add interp dimension in nearest-neighbors interpolation to match

* Add test for using nearest-neighbors interpolation for delay-and-sum

* Lint code
  • Loading branch information
charlesbmi authored Nov 8, 2023
1 parent 113f0f6 commit f560f29
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/neurotechdevkit/imaging/beamform.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ def _interpolate_onto_pixel_grid(
time_idx_float <= (max_time_samples - 1)
)
time_idx_round = time_idx_float.round()
# Add dimension to match other interpolation methods
time_idx_round = time_idx_round.expand_dims("interp")
elif method == InterpolationMethod.LINEAR:
# E.g., a time index of 2.3 gets 0.7*pixel_2 + 0.3*pixel_3
is_valid_time_idx = (time_idx_float >= 0) & (
Expand Down
17 changes: 17 additions & 0 deletions tests/neurotechdevkit/imaging/test_beamform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from scipy.sparse import csr_array

from neurotechdevkit.imaging.beamform import (
InterpolationMethod,
_calculate_time_of_flight,
_directivity,
_optimize_f_number,
Expand Down Expand Up @@ -102,6 +103,22 @@ def test_delay_and_sum(simple_inputs):
assert beamformed_iq_signals.shape == simple_inputs["x"].shape


def test_delay_and_sum_nearest_neighbors_interpolation(simple_inputs):
simple_inputs["method"] = InterpolationMethod.NEAREST

num_time_samples = simple_inputs["num_time_samples"]
num_channels = simple_inputs["num_channels"]

iq_signals = np.random.rand(num_time_samples, num_channels) + 1j * np.random.rand(
num_time_samples, num_channels
)
del simple_inputs["num_time_samples"]
del simple_inputs["num_channels"]

beamformed_iq_signals = beamform_delay_and_sum(iq_signals, **simple_inputs)
assert beamformed_iq_signals.shape == simple_inputs["x"].shape


def test_expect_complex(simple_inputs):
# create a scenario where AssertionError is raised when iq_signals is not complex.
with pytest.raises(AssertionError):
Expand Down

0 comments on commit f560f29

Please sign in to comment.