Skip to content

Commit

Permalink
Fix plot_time_reverse.py time-reverse simulation (#172)
Browse files Browse the repository at this point in the history
* Update plot_time_reverse.py to refactored NDK API where there is a default source

* use _metrics to calculate focal_position in time-reverse example

* No need to print delay label coordinates

* Fix type-hinting errors

* Allow tuple value in metrics

* Fix linting

* Fix assert to support np.int_ return values
---------

Co-authored-by: Charles Guan <[email protected]>
  • Loading branch information
charlesbmi and charlesbmi authored Jul 30, 2024
1 parent 6bd87e2 commit 6c0e074
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 13 deletions.
12 changes: 6 additions & 6 deletions docs/examples/plot_time_reverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def make_scenario(element_delays=None):
point_source = ndk.sources.PointSource2D(
position=true_scenario.target.center,
)
reversed_scenario.sources.append(point_source)
reversed_scenario.sources = [point_source]

reversed_scenario.make_grid()
reversed_scenario.compile_problem()
Expand All @@ -91,6 +91,7 @@ def make_scenario(element_delays=None):
# the true array elements. Here, we coarsely approximate these delays by
# finding the pressure argmax at each element's nearest-neighbor coordinates.


# Map array elements onto the nearest pixels in our simulation
def map_coordinates_to_indices(coordinates, origin, dx):
indices = np.round((coordinates - origin) / dx).astype(int)
Expand All @@ -111,7 +112,7 @@ def map_coordinates_to_indices(coordinates, origin, dx):
element_reverse_delays = np.argmax(pressure_at_elements, axis=1) * result.effective_dt
plt.plot(element_reverse_delays, marker="o")
plt.xlabel("element index")
plt.ylabel("delay [s]")
_ = plt.ylabel("delay [s]")


# %%
Expand Down Expand Up @@ -175,10 +176,9 @@ def map_coordinates_to_indices(coordinates, origin, dx):
# %%
# We can also calculate how far the "time reverse" estimate is from the true
# target.
max_pressure_flat_idx = np.nanargmax(steady_state_pressure)
max_pressure_idx = np.unravel_index(max_pressure_flat_idx, steady_state_pressure.shape)
max_pressure_idx

max_pressure_idx = steady_state_result.metrics["focal_position"]["value"]
assert isinstance(max_pressure_idx, tuple)
assert all(np.issubdtype(type(idx), np.integer) for idx in max_pressure_idx)
grid = steady_state_result.traces.grid.space.grid
focal_point = np.array(
[
Expand Down
3 changes: 1 addition & 2 deletions src/neurotechdevkit/rendering/colormaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,14 +1387,13 @@
from matplotlib.colors import ListedColormap # noqa: E402

cmaps = {}
for (name, data) in (
for name, data in (
("magma", _magma_data),
("inferno", _inferno_data),
("plasma", _plasma_data),
("viridis", _viridis_data),
("parula", _parula_data),
):

cmaps[name] = ListedColormap(data, name=name)

magma = cmaps["magma"]
Expand Down
9 changes: 8 additions & 1 deletion src/neurotechdevkit/results/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def calculate_all_metrics(
result: results.SteadyStateResult,
) -> dict[str, dict[str, float | int | str]]:
) -> dict[str, dict[str, float | int | str | tuple]]:
"""Calculate all metrics for the steady-state result and return as a dictionary.
The keys for the dictionary are the names of the metrics. The value for each metric
Expand All @@ -33,6 +33,13 @@ def calculate_all_metrics(
"unit-of-measurement": "Pa",
"description": ("The peak pressure amplitude within the brain."),
},
"focal_position": {
"value": calculate_focal_position(result, layer="brain"),
"unit-of-measurement": "voxel-index",
"description": (
"The position of the peak pressure amplitude within the brain."
),
},
"focal_volume": {
"value": calculate_focal_volume(result, layer="brain"),
"unit-of-measurement": "voxels",
Expand Down
2 changes: 1 addition & 1 deletion src/neurotechdevkit/results/_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def get_steady_state(self) -> npt.NDArray[np.float_]:
return self.steady_state

@property
def metrics(self) -> dict[str, dict[str, str | float]]:
def metrics(self) -> dict[str, dict[str, str | float | int | tuple]]:
"""A dictionary containing metrics and their descriptions.
The keys for the dictionary are the names of the metrics. The value for each
Expand Down
1 change: 0 additions & 1 deletion src/neurotechdevkit/scenarios/built_in/_scenario_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def make_grid(self):


def _create_scenario_1_mask(material, grid):

# layers are defined by X position
dx = grid.space.spacing[0]

Expand Down
1 change: 0 additions & 1 deletion src/neurotechdevkit/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,7 +1323,6 @@ def _distribute_points_within_element(
n_remaining = n_points - points.shape[0]

if n_remaining > 0:

# First compute the center
x_width = x_max - x_min
centre = np.array([(x_min + x_max) / 2, height / 2])
Expand Down
7 changes: 6 additions & 1 deletion tests/neurotechdevkit/scenarios/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def patched_metric_fns(monkeypatch):
monkeypatch.setattr(
metrics, "calculate_focal_pressure", lambda result, layer: 8.2e5
)
monkeypatch.setattr(
metrics, "calculate_focal_position", lambda result, layer: (15, 24)
)
monkeypatch.setattr(metrics, "calculate_focal_volume", lambda result, layer: 123)
monkeypatch.setattr(metrics, "calculate_focal_gain", lambda result: 2.4)
monkeypatch.setattr(metrics, "calculate_focal_fwhm", lambda result, axis, layer: 6)
Expand Down Expand Up @@ -129,13 +132,14 @@ def test_calculate_all_metrics_has_correct_structure(fake_result, patched_metric
metrics = calculate_all_metrics(fake_result)
for _, data in metrics.items():
assert set(data.keys()) == {"value", "unit-of-measurement", "description"}
assert isinstance(data["value"], (float, int))
assert isinstance(data["value"], (float, int, tuple))


def test_calculate_all_metrics_has_expected_metrics(fake_result, patched_metric_fns):
"""Verify that the metrics data contains the expected set of metrics"""
expected_metrics = [
"focal_pressure",
"focal_position",
"focal_volume",
"focal_gain",
"FWHM_x",
Expand All @@ -158,6 +162,7 @@ def test_calculate_all_metrics_conversions(fake_result, patched_metric_fns):
"""
metrics = calculate_all_metrics(fake_result)
np.testing.assert_allclose(metrics["focal_pressure"]["value"], 8.2e5)
assert metrics["focal_position"]["value"] == (15, 24)
assert metrics["focal_volume"]["value"] == 123
np.testing.assert_allclose(metrics["focal_gain"]["value"], 2.4)
assert metrics["FWHM_x"]["value"] == 6
Expand Down

0 comments on commit 6c0e074

Please sign in to comment.