From 6c0e0749a811b55f83b8d045185f20ac3a8941ba Mon Sep 17 00:00:00 2001 From: charles <3221512+charlesbmi@users.noreply.github.com> Date: Tue, 30 Jul 2024 02:07:55 -0700 Subject: [PATCH] Fix plot_time_reverse.py time-reverse simulation (#172) * Update plot_time_reverse.py to refactored NDK API where there is a default source * use _metrics to calculate focal_position in time-reverse example * No need to print delay label coordinates * Fix type-hinting errors * Allow tuple value in metrics * Fix linting * Fix assert to support np.int_ return values --------- Co-authored-by: Charles Guan <3221512+charlesincharge@users.noreply.github.com> --- docs/examples/plot_time_reverse.py | 12 ++++++------ src/neurotechdevkit/rendering/colormaps.py | 3 +-- src/neurotechdevkit/results/_metrics.py | 9 ++++++++- src/neurotechdevkit/results/_results.py | 2 +- .../scenarios/built_in/_scenario_1.py | 1 - src/neurotechdevkit/sources.py | 1 - tests/neurotechdevkit/scenarios/test_metrics.py | 7 ++++++- 7 files changed, 22 insertions(+), 13 deletions(-) diff --git a/docs/examples/plot_time_reverse.py b/docs/examples/plot_time_reverse.py index 7f507810..e76f8f0e 100644 --- a/docs/examples/plot_time_reverse.py +++ b/docs/examples/plot_time_reverse.py @@ -73,7 +73,7 @@ def make_scenario(element_delays=None): point_source = ndk.sources.PointSource2D( position=true_scenario.target.center, ) -reversed_scenario.sources.append(point_source) +reversed_scenario.sources = [point_source] reversed_scenario.make_grid() reversed_scenario.compile_problem() @@ -91,6 +91,7 @@ def make_scenario(element_delays=None): # the true array elements. Here, we coarsely approximate these delays by # finding the pressure argmax at each element's nearest-neighbor coordinates. + # Map array elements onto the nearest pixels in our simulation def map_coordinates_to_indices(coordinates, origin, dx): indices = np.round((coordinates - origin) / dx).astype(int) @@ -111,7 +112,7 @@ def map_coordinates_to_indices(coordinates, origin, dx): element_reverse_delays = np.argmax(pressure_at_elements, axis=1) * result.effective_dt plt.plot(element_reverse_delays, marker="o") plt.xlabel("element index") -plt.ylabel("delay [s]") +_ = plt.ylabel("delay [s]") # %% @@ -175,10 +176,9 @@ def map_coordinates_to_indices(coordinates, origin, dx): # %% # We can also calculate how far the "time reverse" estimate is from the true # target. -max_pressure_flat_idx = np.nanargmax(steady_state_pressure) -max_pressure_idx = np.unravel_index(max_pressure_flat_idx, steady_state_pressure.shape) -max_pressure_idx - +max_pressure_idx = steady_state_result.metrics["focal_position"]["value"] +assert isinstance(max_pressure_idx, tuple) +assert all(np.issubdtype(type(idx), np.integer) for idx in max_pressure_idx) grid = steady_state_result.traces.grid.space.grid focal_point = np.array( [ diff --git a/src/neurotechdevkit/rendering/colormaps.py b/src/neurotechdevkit/rendering/colormaps.py index d2ba2aa8..b28e79b9 100644 --- a/src/neurotechdevkit/rendering/colormaps.py +++ b/src/neurotechdevkit/rendering/colormaps.py @@ -1387,14 +1387,13 @@ from matplotlib.colors import ListedColormap # noqa: E402 cmaps = {} -for (name, data) in ( +for name, data in ( ("magma", _magma_data), ("inferno", _inferno_data), ("plasma", _plasma_data), ("viridis", _viridis_data), ("parula", _parula_data), ): - cmaps[name] = ListedColormap(data, name=name) magma = cmaps["magma"] diff --git a/src/neurotechdevkit/results/_metrics.py b/src/neurotechdevkit/results/_metrics.py index b7f2ac38..44a309bf 100644 --- a/src/neurotechdevkit/results/_metrics.py +++ b/src/neurotechdevkit/results/_metrics.py @@ -11,7 +11,7 @@ def calculate_all_metrics( result: results.SteadyStateResult, -) -> dict[str, dict[str, float | int | str]]: +) -> dict[str, dict[str, float | int | str | tuple]]: """Calculate all metrics for the steady-state result and return as a dictionary. The keys for the dictionary are the names of the metrics. The value for each metric @@ -33,6 +33,13 @@ def calculate_all_metrics( "unit-of-measurement": "Pa", "description": ("The peak pressure amplitude within the brain."), }, + "focal_position": { + "value": calculate_focal_position(result, layer="brain"), + "unit-of-measurement": "voxel-index", + "description": ( + "The position of the peak pressure amplitude within the brain." + ), + }, "focal_volume": { "value": calculate_focal_volume(result, layer="brain"), "unit-of-measurement": "voxels", diff --git a/src/neurotechdevkit/results/_results.py b/src/neurotechdevkit/results/_results.py index d92ebd83..dd36f6c0 100644 --- a/src/neurotechdevkit/results/_results.py +++ b/src/neurotechdevkit/results/_results.py @@ -183,7 +183,7 @@ def get_steady_state(self) -> npt.NDArray[np.float_]: return self.steady_state @property - def metrics(self) -> dict[str, dict[str, str | float]]: + def metrics(self) -> dict[str, dict[str, str | float | int | tuple]]: """A dictionary containing metrics and their descriptions. The keys for the dictionary are the names of the metrics. The value for each diff --git a/src/neurotechdevkit/scenarios/built_in/_scenario_1.py b/src/neurotechdevkit/scenarios/built_in/_scenario_1.py index a2088915..38ef1b13 100644 --- a/src/neurotechdevkit/scenarios/built_in/_scenario_1.py +++ b/src/neurotechdevkit/scenarios/built_in/_scenario_1.py @@ -167,7 +167,6 @@ def make_grid(self): def _create_scenario_1_mask(material, grid): - # layers are defined by X position dx = grid.space.spacing[0] diff --git a/src/neurotechdevkit/sources.py b/src/neurotechdevkit/sources.py index f1db86e1..89bc4b2d 100644 --- a/src/neurotechdevkit/sources.py +++ b/src/neurotechdevkit/sources.py @@ -1323,7 +1323,6 @@ def _distribute_points_within_element( n_remaining = n_points - points.shape[0] if n_remaining > 0: - # First compute the center x_width = x_max - x_min centre = np.array([(x_min + x_max) / 2, height / 2]) diff --git a/tests/neurotechdevkit/scenarios/test_metrics.py b/tests/neurotechdevkit/scenarios/test_metrics.py index e0bc0605..88252c4b 100644 --- a/tests/neurotechdevkit/scenarios/test_metrics.py +++ b/tests/neurotechdevkit/scenarios/test_metrics.py @@ -42,6 +42,9 @@ def patched_metric_fns(monkeypatch): monkeypatch.setattr( metrics, "calculate_focal_pressure", lambda result, layer: 8.2e5 ) + monkeypatch.setattr( + metrics, "calculate_focal_position", lambda result, layer: (15, 24) + ) monkeypatch.setattr(metrics, "calculate_focal_volume", lambda result, layer: 123) monkeypatch.setattr(metrics, "calculate_focal_gain", lambda result: 2.4) monkeypatch.setattr(metrics, "calculate_focal_fwhm", lambda result, axis, layer: 6) @@ -129,13 +132,14 @@ def test_calculate_all_metrics_has_correct_structure(fake_result, patched_metric metrics = calculate_all_metrics(fake_result) for _, data in metrics.items(): assert set(data.keys()) == {"value", "unit-of-measurement", "description"} - assert isinstance(data["value"], (float, int)) + assert isinstance(data["value"], (float, int, tuple)) def test_calculate_all_metrics_has_expected_metrics(fake_result, patched_metric_fns): """Verify that the metrics data contains the expected set of metrics""" expected_metrics = [ "focal_pressure", + "focal_position", "focal_volume", "focal_gain", "FWHM_x", @@ -158,6 +162,7 @@ def test_calculate_all_metrics_conversions(fake_result, patched_metric_fns): """ metrics = calculate_all_metrics(fake_result) np.testing.assert_allclose(metrics["focal_pressure"]["value"], 8.2e5) + assert metrics["focal_position"]["value"] == (15, 24) assert metrics["focal_volume"]["value"] == 123 np.testing.assert_allclose(metrics["focal_gain"]["value"], 2.4) assert metrics["FWHM_x"]["value"] == 6