Skip to content

Commit

Permalink
Cleaning up scenario __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
NewtonSander committed Jul 24, 2023
1 parent 0a8c938 commit 8ab282c
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 68 deletions.
12 changes: 6 additions & 6 deletions docs/examples/plot_full_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ class FullScenario(Scenario2D):
num_points=1000,
)
]

def __init__(self, scenario_id: str, material_outline_upsample_factor: int = 8):
super().__init__(scenario_id, material_outline_upsample_factor)
material_outline_upsample_factor = 8

def compile_problem(self, center_frequency) -> stride.Problem:
"""The problem definition for the scenario."""
Expand Down Expand Up @@ -149,15 +147,17 @@ def _fill_mask(mask, start, end, dx):


# %%
# ## Creating the scenario
scenario = FullScenario(FullScenario.scenario_id)
# ## Running the scenario

scenario = FullScenario()

# %%
# ## Rendering the scenario layout
scenario.render_layout()

# %%
# ## Running the scenario
# ## Rendering the simulation
scenario.compile_problem(center_frequency=5e5)
result = scenario.simulate_steady_state()
assert isinstance(result, SteadyStateResult2D)
result.render_steady_state_amplitudes(show_material_outlines=False)
Expand Down
2 changes: 1 addition & 1 deletion src/neurotechdevkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def make(scenario_id: str) -> scenarios.Scenario:
f"Scenario '{scenario_id}' does not exist. Please refer to documentation"
" for the list of provided scenarios."
)
return _scenario_map[scenario_id](scenario_id=scenario_id) # type: ignore
return _scenario_map[scenario_id]() # type: ignore


_scenario_map = {
Expand Down
18 changes: 1 addition & 17 deletions src/neurotechdevkit/scenarios/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,9 @@ class Scenario(abc.ABC):
target: Target

scenario_id: str
material_outline_upsample_factor: int
slice_axis: int
slice_position: float

def __init__(
self,
scenario_id: str,
material_outline_upsample_factor: int = 16,
):
"""
Initialize a new scenario.
Args:
scenario_id (str): An identifier for the scenario.
material_outline_upsample_factor (int, optional): The factor by which to
upsample the material outline. Defaults to 16.
"""
self.scenario_id = scenario_id
self.material_outline_upsample_factor = material_outline_upsample_factor
material_outline_upsample_factor: int = 16

def render_layout(
self,
Expand Down
16 changes: 2 additions & 14 deletions src/neurotechdevkit/scenarios/_scenario_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,7 @@ class Scenario1_2D(Scenario1, Scenario2D):
)
]
origin = np.array([0.0, -0.035])

def __init__(self, scenario_id: str, material_outline_upsample_factor: int = 8):
"""Instantiate Scenario1 with overwritten material_outline_upsample_factor."""
super().__init__(
scenario_id=scenario_id,
material_outline_upsample_factor=material_outline_upsample_factor,
)
material_outline_upsample_factor = 8

def compile_problem(self, center_frequency: float) -> stride.Problem:
"""
Expand Down Expand Up @@ -175,13 +169,7 @@ class Scenario1_3D(Scenario1, Scenario3D):
)
slice_axis = 1
slice_position = 0.0

def __init__(self, scenario_id, material_outline_upsample_factor: int = 8):
"""Instantiate Scenario1 with overwritten material_outline_upsample_factor."""
super().__init__(
scenario_id=scenario_id,
material_outline_upsample_factor=material_outline_upsample_factor,
)
material_outline_upsample_factor = 8

def compile_problem(self, center_frequency: float) -> stride.Problem:
"""
Expand Down
16 changes: 2 additions & 14 deletions src/neurotechdevkit/scenarios/_scenario_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,7 @@ class Scenario2_2D(Scenario2, Scenario2D):
num_points=1000,
)
]

def __init__(self, scenario_id, material_outline_upsample_factor: int = 4):
"""Instantiate Scenario2 with overwritten material_outline_upsample_factor."""
super().__init__(
scenario_id=scenario_id,
material_outline_upsample_factor=material_outline_upsample_factor,
)
material_outline_upsample_factor = 4

def compile_problem(self, center_frequency: float) -> stride.Problem:
"""
Expand Down Expand Up @@ -183,13 +177,7 @@ class Scenario2_3D(Scenario2, Scenario3D):
)
slice_axis = 2
slice_position = 0.0

def __init__(self, scenario_id, material_outline_upsample_factor: int = 4):
"""Instantiate Scenario2 with overwritten material_outline_upsample_factor."""
super().__init__(
scenario_id=scenario_id,
material_outline_upsample_factor=material_outline_upsample_factor,
)
material_outline_upsample_factor = 4

def compile_problem(self, center_frequency: float) -> stride.Problem:
"""
Expand Down
3 changes: 0 additions & 3 deletions tests/neurotechdevkit/scenarios/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ class ScenarioBaseTester(Scenario):

def __init__(self):
self.problem = self._compile_problem(center_frequency=5e5)
super().__init__(
scenario_id=self.scenario_id, material_outline_upsample_factor=3
)

def _compile_problem(self, center_frequency: float) -> stride.Problem:
extent = np.array([2.0, 3.0, 4.0])
Expand Down
15 changes: 4 additions & 11 deletions tests/neurotechdevkit/scenarios/test_materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,10 @@ def compare_structs(struct1: Struct, struct2: Struct):
assert struct1.render_color == struct2.render_color


class BaseScenario(Scenario2D):
"""A scenario for testing the materials module."""

def __init__(self):
super().__init__(scenario_id="test", material_outline_upsample_factor=16)


def test_custom_material_property():
"""Test that a custom material property is used."""

class ScenarioWithCustomMaterialProperties(BaseScenario):
class ScenarioWithCustomMaterialProperties(Scenario2D):
material_layers = ["brain"]
material_properties = {
"brain": Material(vp=1600.0, rho=1100.0, alpha=0.0, render_color="#2E86AB")
Expand All @@ -43,7 +36,7 @@ class ScenarioWithCustomMaterialProperties(BaseScenario):
def test_new_material():
"""Test that a new material is used."""

class ScenarioWithCustomMaterial(BaseScenario):
class ScenarioWithCustomMaterial(Scenario2D):
material_layers = ["brain", "eye"]
material_properties = {
"eye": Material(vp=1600.0, rho=1100.0, alpha=0.0, render_color="#2E86AB")
Expand All @@ -64,7 +57,7 @@ class ScenarioWithCustomMaterial(BaseScenario):
def test_material_absorption_is_calculated():
"""Test that the material absorption is calculated for a frequency !=500e3."""

class ScenarioWithBrainMaterial(BaseScenario):
class ScenarioWithBrainMaterial(Scenario2D):
material_layers = ["brain"]
material_properties = {}

Expand All @@ -79,7 +72,7 @@ class ScenarioWithBrainMaterial(BaseScenario):
def test_unknown_material_without_properties():
"""Test that an unknown material without properties raises an error."""

class ScenarioWithCustomMaterial(BaseScenario):
class ScenarioWithCustomMaterial(Scenario2D):
material_layers = ["unknown_material"]
material_properties = {}

Expand Down
4 changes: 2 additions & 2 deletions tests/neurotechdevkit/scenarios/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def pulsed_data_2d():
@pytest.fixture
def a_test_scenario_2d():
"""A real 2D scenario that can be saved to disk and reloaded."""
scenario = scenarios.Scenario1_2D(scenario_id=scenarios.Scenario1_2D.scenario_id)
scenario = scenarios.Scenario1_2D()
scenario.add_source(
sources.FocusedSource2D(
position=np.array([0.02, 0.02]),
Expand All @@ -125,7 +125,7 @@ def a_test_scenario_2d():
@pytest.fixture
def a_test_scenario_3d():
"""A real 3D scenario that can be saved to disk and reloaded."""
scenario = scenarios.Scenario1_3D(scenario_id=scenarios.Scenario1_3D.scenario_id)
scenario = scenarios.Scenario1_3D()
scenario.add_source(
sources.FocusedSource3D(
position=np.array([0.02, 0.02, 0.0]),
Expand Down

0 comments on commit 8ab282c

Please sign in to comment.