diff --git a/koswat/strategies/order_strategy/order_strategy.py b/koswat/strategies/order_strategy/order_strategy.py index 1b3b267d..cf788e19 100644 --- a/koswat/strategies/order_strategy/order_strategy.py +++ b/koswat/strategies/order_strategy/order_strategy.py @@ -162,11 +162,13 @@ def apply_strategy(self, strategy_input: StrategyInput) -> StrategyOutput: _strategy_reinforcements = self.get_strategy_reinforcements( strategy_input.strategy_locations, self.reinforcement_order ) - OrderStrategyBuffering.with_strategy( - self.reinforcement_order, strategy_input.reinforcement_min_buffer + OrderStrategyBuffering( + reinforcement_order=self.reinforcement_order, + reinforcement_min_buffer=strategy_input.reinforcement_min_buffer, ).apply(_strategy_reinforcements) - OrderStrategyClustering.with_strategy( - self.reinforcement_order, strategy_input.reinforcement_min_length + OrderStrategyClustering( + reinforcement_order=self.reinforcement_order, + reinforcement_min_length=strategy_input.reinforcement_min_length, ).apply(_strategy_reinforcements) return StrategyOutput( location_reinforcements=_strategy_reinforcements, diff --git a/koswat/strategies/order_strategy/order_strategy_buffering.py b/koswat/strategies/order_strategy/order_strategy_buffering.py index 4c8b193c..48ffe204 100644 --- a/koswat/strategies/order_strategy/order_strategy_buffering.py +++ b/koswat/strategies/order_strategy/order_strategy_buffering.py @@ -1,3 +1,5 @@ +from dataclasses import dataclass + from koswat.dike_reinforcements.reinforcement_profile.reinforcement_profile_protocol import ( ReinforcementProfileProtocol, ) @@ -8,21 +10,18 @@ from koswat.strategies.strategy_step.strategy_step_enum import StrategyStepEnum +@dataclass class OrderStrategyBuffering(OrderStrategyBase): + """ + Applies buffering, through masks, to each location's pre-assigned reinforcement. + The result of the `apply` method will be the locations with the best + reinforcement fit (lowest index from `reinforcement_order`) that fulfills the + `reinforcement_min_buffer` requirement. + """ + reinforcement_order: list[type[ReinforcementProfileProtocol]] reinforcement_min_buffer: float - @classmethod - def with_strategy( - cls, - reinforcement_order: list[type[ReinforcementProfileProtocol]], - reinforcement_min_buffer: float, - ): - _this = cls() - _this.reinforcement_order = reinforcement_order - _this.reinforcement_min_buffer = reinforcement_min_buffer - return _this - def _get_buffer_mask( self, location_reinforcements: list[StrategyLocationReinforcement] ) -> list[int]: @@ -42,7 +41,7 @@ def _get_buffer_mask( _upper_limit = int( min( _new_visited + self.reinforcement_min_buffer, - _len_location_reinforcements - 1, + _len_location_reinforcements, ) ) diff --git a/koswat/strategies/order_strategy/order_strategy_clustering.py b/koswat/strategies/order_strategy/order_strategy_clustering.py index ed330920..c3c5c75e 100644 --- a/koswat/strategies/order_strategy/order_strategy_clustering.py +++ b/koswat/strategies/order_strategy/order_strategy_clustering.py @@ -1,4 +1,5 @@ import logging +from dataclasses import dataclass from koswat.dike_reinforcements.reinforcement_profile.reinforcement_profile_protocol import ( ReinforcementProfileProtocol, @@ -10,21 +11,19 @@ ) +@dataclass class OrderStrategyClustering(OrderStrategyBase): + """ + Applies clustering, to the whole collection of reinforcements + (`location_reinforcements: list[StrategyLocationReinforcement]`). + The result of the `apply` method will be the locations with the best + reinforcement fit (lowest index from `reinforcement_order`) that fulfills the + `reinforcement_min_length` requirement. + """ + reinforcement_order: list[type[ReinforcementProfileProtocol]] reinforcement_min_length: float - @classmethod - def with_strategy( - cls, - reinforcement_order: list[type[ReinforcementProfileProtocol]], - reinforcement_min_length: float, - ): - _this = cls() - _this.reinforcement_order = reinforcement_order - _this.reinforcement_min_length = reinforcement_min_length - return _this - def _get_reinforcement_order_clusters( self, location_reinforcements: list[StrategyLocationReinforcement], @@ -60,7 +59,7 @@ def apply( _available_clusters = self._get_reinforcement_order_clusters( location_reinforcements ) - _reinforcements_order_max_idx = len(self.reinforcement_order) + _reinforcements_order_max_idx = len(self.reinforcement_order) - 1 for _target_idx, _reinforcement_type in enumerate( self.reinforcement_order[:-1] ): diff --git a/tests/strategies/order_strategy/test_order_strategy_buffering.py b/tests/strategies/order_strategy/test_order_strategy_buffering.py index 368eb66c..7da17cde 100644 --- a/tests/strategies/order_strategy/test_order_strategy_buffering.py +++ b/tests/strategies/order_strategy/test_order_strategy_buffering.py @@ -8,7 +8,10 @@ class TestOrderStrategyBuffering: def test_initialize(self): - _strategy = OrderStrategyBuffering() + _strategy = OrderStrategyBuffering( + reinforcement_order=[], + reinforcement_min_buffer=float("nan"), + ) assert isinstance(_strategy, OrderStrategyBuffering) assert isinstance(_strategy, OrderStrategyBase) @@ -19,11 +22,11 @@ def test_apply_given_docs_example(self, example_strategy_input: StrategyInput): example_strategy_input.strategy_locations, _reinforcement_order, ) - _strategy = OrderStrategyBuffering() - _strategy.reinforcement_order = _reinforcement_order - _strategy.reinforcement_min_buffer = ( - example_strategy_input.reinforcement_min_buffer + _strategy = OrderStrategyBuffering( + reinforcement_order=_reinforcement_order, + reinforcement_min_buffer=example_strategy_input.reinforcement_min_buffer, ) + _expected_result_idx = [0, 0, 3, 3, 3, 3, 0, 4, 4, 4] _expected_result = list( map(lambda x: _reinforcement_order[x], _expected_result_idx) @@ -47,10 +50,9 @@ def test__get_buffer_mask_given_docs_example( example_strategy_input.strategy_locations, _order_reinforcement, ) - _strategy = OrderStrategyBuffering() - _strategy.reinforcement_order = _order_reinforcement - _strategy.reinforcement_min_buffer = ( - example_strategy_input.reinforcement_min_buffer + _strategy = OrderStrategyBuffering( + reinforcement_order=_order_reinforcement, + reinforcement_min_buffer=example_strategy_input.reinforcement_min_buffer, ) # 2. Run test. @@ -58,3 +60,26 @@ def test__get_buffer_mask_given_docs_example( # 3. Verify expectations. assert _mask_result == [0, 0, 3, 3, 3, 3, 0, 4, 4, 4] + + def test__get_modified_example_last_location_gets_buffered( + self, example_strategy_input: StrategyInput + ): + """ + This test fixes the problem related to Koswat #220. + """ + # 1. Define test data. + _order_reinforcement = OrderStrategy.get_default_order_for_reinforcements() + _reinforcements = OrderStrategy.get_strategy_reinforcements( + example_strategy_input.strategy_locations, + _order_reinforcement, + )[:6] + _strategy = OrderStrategyBuffering( + reinforcement_order=_order_reinforcement, + reinforcement_min_buffer=3, + ) + + # 2. Run test. + _mask_result = _strategy._get_buffer_mask(_reinforcements) + + # 3. Verify expectations. + assert _mask_result == [3, 3, 3, 3, 3, 3] diff --git a/tests/strategies/order_strategy/test_order_strategy_clustering.py b/tests/strategies/order_strategy/test_order_strategy_clustering.py index e045be47..66ddbb2e 100644 --- a/tests/strategies/order_strategy/test_order_strategy_clustering.py +++ b/tests/strategies/order_strategy/test_order_strategy_clustering.py @@ -21,7 +21,9 @@ class TestOrderStrategyClustering: def test_initialize(self): - _strategy = OrderStrategyClustering() + _strategy = OrderStrategyClustering( + reinforcement_min_length=float("nan"), reinforcement_order=[] + ) assert isinstance(_strategy, OrderStrategyClustering) assert isinstance(_strategy, OrderStrategyBase) @@ -33,12 +35,9 @@ def test_apply_given_example_docs( ], ): # 1. Define test data. - _strategy = OrderStrategyClustering() - _strategy.reinforcement_min_length = ( - example_strategy_input.reinforcement_min_length - ) - _strategy.reinforcement_order = ( - OrderStrategy.get_default_order_for_reinforcements() + _strategy = OrderStrategyClustering( + reinforcement_min_length=example_strategy_input.reinforcement_min_length, + reinforcement_order=OrderStrategy.get_default_order_for_reinforcements(), ) # 2. Run test. @@ -65,10 +64,9 @@ def test_apply_given_cluster_with_lower_type( ], ): # 1. Define test data. - _strategy = OrderStrategyClustering() - _strategy.reinforcement_min_length = 2 - _strategy.reinforcement_order = ( - OrderStrategy.get_default_order_for_reinforcements() + _strategy = OrderStrategyClustering( + reinforcement_min_length=2, + reinforcement_order=OrderStrategy.get_default_order_for_reinforcements(), ) _location_reinforcements = example_location_reinforcements_with_buffering @@ -107,12 +105,9 @@ def test__get_reinforcement_order_clusters( ], ): # 1. Define test data. - _strategy = OrderStrategyClustering() - _strategy.reinforcement_min_length = ( - example_strategy_input.reinforcement_min_length - ) - _strategy.reinforcement_order = ( - OrderStrategy.get_default_order_for_reinforcements() + _strategy = OrderStrategyClustering( + reinforcement_min_length=example_strategy_input.reinforcement_min_length, + reinforcement_order=OrderStrategy.get_default_order_for_reinforcements(), ) # 2. Run test. diff --git a/tests/test_main.py b/tests/test_main.py index f0dd160f..f161d8da 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,11 +1,14 @@ import shutil from pathlib import Path +import pytest from click.testing import CliRunner from koswat import __main__ from tests import test_data, test_results +issues_tests = test_data.joinpath("issues") + class TestMain: def test_given_invalid_path_raises_value_error(self): @@ -47,3 +50,40 @@ def test_given_valid_input_succeeds(self): assert ( _log.read_text().find("ERROR") == -1 ), "ERROR found in the log, run was not succesful." + + @pytest.mark.skipif( + not any(issues_tests.glob("*")), + reason="Only meant to run locally with issue cases.", + ) + @pytest.mark.parametrize( + "ini_file_location", + [ + pytest.param( + issues_tests.joinpath("KOSWAT_220", "KOSWAT_analyse_RaLi.ini"), + id="Koswat 220", + ) + ], + ) + def test_given_issue_case(self, ini_file_location: Path): + # 1. Define test data. + assert ini_file_location.is_file() + # Ensure we have a clean results dir. + _log_dir = ini_file_location.parent.joinpath("log_output") + if _log_dir.exists(): + shutil.rmtree(_log_dir) + _log_dir.mkdir(parents=True) + + _cli_arg = f'--input_file "{ini_file_location}" --log_output "{_log_dir}"' + + # 2. Run test. + _run_result = CliRunner().invoke( + __main__.run_analysis, + _cli_arg, + ) + # 3. Verify final expectations. + assert _run_result.exit_code == 0 + _log: Path = next(_log_dir.glob("*.log"), None) + assert _log and _log.is_file(), "Log file was not generated." + assert ( + _log.read_text().find("ERROR") == -1 + ), "ERROR found in the log, run was not succesful."