Skip to content

Commit

Permalink
fix: Corrected buffering lower limit and max reinforcement for cluste…
Browse files Browse the repository at this point in the history
…ring. (#222)

* fix: corrected wrong upper limit for buffering; corrected max index for clustering

* test: Adapted failing tests

* test: Added test to cover upper_limit bug

* docs: Added missing docstring

* test: Parametrized issues test

* chore: Update koswat/strategies/order_strategy/order_strategy_buffering.py

Co-authored-by: Ardt Klapwijk <[email protected]>

---------

Co-authored-by: Ardt Klapwijk <[email protected]>
  • Loading branch information
Carsopre and ArdtK authored Nov 11, 2024
1 parent b672c7b commit b05ad5e
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 54 deletions.
10 changes: 6 additions & 4 deletions koswat/strategies/order_strategy/order_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,13 @@ def apply_strategy(self, strategy_input: StrategyInput) -> StrategyOutput:
_strategy_reinforcements = self.get_strategy_reinforcements(
strategy_input.strategy_locations, self.reinforcement_order
)
OrderStrategyBuffering.with_strategy(
self.reinforcement_order, strategy_input.reinforcement_min_buffer
OrderStrategyBuffering(
reinforcement_order=self.reinforcement_order,
reinforcement_min_buffer=strategy_input.reinforcement_min_buffer,
).apply(_strategy_reinforcements)
OrderStrategyClustering.with_strategy(
self.reinforcement_order, strategy_input.reinforcement_min_length
OrderStrategyClustering(
reinforcement_order=self.reinforcement_order,
reinforcement_min_length=strategy_input.reinforcement_min_length,
).apply(_strategy_reinforcements)
return StrategyOutput(
location_reinforcements=_strategy_reinforcements,
Expand Down
23 changes: 11 additions & 12 deletions koswat/strategies/order_strategy/order_strategy_buffering.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from dataclasses import dataclass

from koswat.dike_reinforcements.reinforcement_profile.reinforcement_profile_protocol import (
ReinforcementProfileProtocol,
)
Expand All @@ -8,21 +10,18 @@
from koswat.strategies.strategy_step.strategy_step_enum import StrategyStepEnum


@dataclass
class OrderStrategyBuffering(OrderStrategyBase):
"""
Applies buffering, through masks, to each location's pre-assigned reinforcement.
The result of the `apply` method will be the locations with the best
reinforcement fit (lowest index from `reinforcement_order`) that fulfills the
`reinforcement_min_buffer` requirement.
"""

reinforcement_order: list[type[ReinforcementProfileProtocol]]
reinforcement_min_buffer: float

@classmethod
def with_strategy(
cls,
reinforcement_order: list[type[ReinforcementProfileProtocol]],
reinforcement_min_buffer: float,
):
_this = cls()
_this.reinforcement_order = reinforcement_order
_this.reinforcement_min_buffer = reinforcement_min_buffer
return _this

def _get_buffer_mask(
self, location_reinforcements: list[StrategyLocationReinforcement]
) -> list[int]:
Expand All @@ -42,7 +41,7 @@ def _get_buffer_mask(
_upper_limit = int(
min(
_new_visited + self.reinforcement_min_buffer,
_len_location_reinforcements - 1,
_len_location_reinforcements,
)
)

Expand Down
23 changes: 11 additions & 12 deletions koswat/strategies/order_strategy/order_strategy_clustering.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from dataclasses import dataclass

from koswat.dike_reinforcements.reinforcement_profile.reinforcement_profile_protocol import (
ReinforcementProfileProtocol,
Expand All @@ -10,21 +11,19 @@
)


@dataclass
class OrderStrategyClustering(OrderStrategyBase):
"""
Applies clustering, to the whole collection of reinforcements
(`location_reinforcements: list[StrategyLocationReinforcement]`).
The result of the `apply` method will be the locations with the best
reinforcement fit (lowest index from `reinforcement_order`) that fulfills the
`reinforcement_min_length` requirement.
"""

reinforcement_order: list[type[ReinforcementProfileProtocol]]
reinforcement_min_length: float

@classmethod
def with_strategy(
cls,
reinforcement_order: list[type[ReinforcementProfileProtocol]],
reinforcement_min_length: float,
):
_this = cls()
_this.reinforcement_order = reinforcement_order
_this.reinforcement_min_length = reinforcement_min_length
return _this

def _get_reinforcement_order_clusters(
self,
location_reinforcements: list[StrategyLocationReinforcement],
Expand Down Expand Up @@ -60,7 +59,7 @@ def apply(
_available_clusters = self._get_reinforcement_order_clusters(
location_reinforcements
)
_reinforcements_order_max_idx = len(self.reinforcement_order)
_reinforcements_order_max_idx = len(self.reinforcement_order) - 1
for _target_idx, _reinforcement_type in enumerate(
self.reinforcement_order[:-1]
):
Expand Down
43 changes: 34 additions & 9 deletions tests/strategies/order_strategy/test_order_strategy_buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

class TestOrderStrategyBuffering:
def test_initialize(self):
_strategy = OrderStrategyBuffering()
_strategy = OrderStrategyBuffering(
reinforcement_order=[],
reinforcement_min_buffer=float("nan"),
)
assert isinstance(_strategy, OrderStrategyBuffering)
assert isinstance(_strategy, OrderStrategyBase)

Expand All @@ -19,11 +22,11 @@ def test_apply_given_docs_example(self, example_strategy_input: StrategyInput):
example_strategy_input.strategy_locations,
_reinforcement_order,
)
_strategy = OrderStrategyBuffering()
_strategy.reinforcement_order = _reinforcement_order
_strategy.reinforcement_min_buffer = (
example_strategy_input.reinforcement_min_buffer
_strategy = OrderStrategyBuffering(
reinforcement_order=_reinforcement_order,
reinforcement_min_buffer=example_strategy_input.reinforcement_min_buffer,
)

_expected_result_idx = [0, 0, 3, 3, 3, 3, 0, 4, 4, 4]
_expected_result = list(
map(lambda x: _reinforcement_order[x], _expected_result_idx)
Expand All @@ -47,14 +50,36 @@ def test__get_buffer_mask_given_docs_example(
example_strategy_input.strategy_locations,
_order_reinforcement,
)
_strategy = OrderStrategyBuffering()
_strategy.reinforcement_order = _order_reinforcement
_strategy.reinforcement_min_buffer = (
example_strategy_input.reinforcement_min_buffer
_strategy = OrderStrategyBuffering(
reinforcement_order=_order_reinforcement,
reinforcement_min_buffer=example_strategy_input.reinforcement_min_buffer,
)

# 2. Run test.
_mask_result = _strategy._get_buffer_mask(_reinforcements)

# 3. Verify expectations.
assert _mask_result == [0, 0, 3, 3, 3, 3, 0, 4, 4, 4]

def test__get_modified_example_last_location_gets_buffered(
self, example_strategy_input: StrategyInput
):
"""
This test fixes the problem related to Koswat #220.
"""
# 1. Define test data.
_order_reinforcement = OrderStrategy.get_default_order_for_reinforcements()
_reinforcements = OrderStrategy.get_strategy_reinforcements(
example_strategy_input.strategy_locations,
_order_reinforcement,
)[:6]
_strategy = OrderStrategyBuffering(
reinforcement_order=_order_reinforcement,
reinforcement_min_buffer=3,
)

# 2. Run test.
_mask_result = _strategy._get_buffer_mask(_reinforcements)

# 3. Verify expectations.
assert _mask_result == [3, 3, 3, 3, 3, 3]
29 changes: 12 additions & 17 deletions tests/strategies/order_strategy/test_order_strategy_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

class TestOrderStrategyClustering:
def test_initialize(self):
_strategy = OrderStrategyClustering()
_strategy = OrderStrategyClustering(
reinforcement_min_length=float("nan"), reinforcement_order=[]
)
assert isinstance(_strategy, OrderStrategyClustering)
assert isinstance(_strategy, OrderStrategyBase)

Expand All @@ -33,12 +35,9 @@ def test_apply_given_example_docs(
],
):
# 1. Define test data.
_strategy = OrderStrategyClustering()
_strategy.reinforcement_min_length = (
example_strategy_input.reinforcement_min_length
)
_strategy.reinforcement_order = (
OrderStrategy.get_default_order_for_reinforcements()
_strategy = OrderStrategyClustering(
reinforcement_min_length=example_strategy_input.reinforcement_min_length,
reinforcement_order=OrderStrategy.get_default_order_for_reinforcements(),
)

# 2. Run test.
Expand All @@ -65,10 +64,9 @@ def test_apply_given_cluster_with_lower_type(
],
):
# 1. Define test data.
_strategy = OrderStrategyClustering()
_strategy.reinforcement_min_length = 2
_strategy.reinforcement_order = (
OrderStrategy.get_default_order_for_reinforcements()
_strategy = OrderStrategyClustering(
reinforcement_min_length=2,
reinforcement_order=OrderStrategy.get_default_order_for_reinforcements(),
)

_location_reinforcements = example_location_reinforcements_with_buffering
Expand Down Expand Up @@ -107,12 +105,9 @@ def test__get_reinforcement_order_clusters(
],
):
# 1. Define test data.
_strategy = OrderStrategyClustering()
_strategy.reinforcement_min_length = (
example_strategy_input.reinforcement_min_length
)
_strategy.reinforcement_order = (
OrderStrategy.get_default_order_for_reinforcements()
_strategy = OrderStrategyClustering(
reinforcement_min_length=example_strategy_input.reinforcement_min_length,
reinforcement_order=OrderStrategy.get_default_order_for_reinforcements(),
)

# 2. Run test.
Expand Down
40 changes: 40 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import shutil
from pathlib import Path

import pytest
from click.testing import CliRunner

from koswat import __main__
from tests import test_data, test_results

issues_tests = test_data.joinpath("issues")


class TestMain:
def test_given_invalid_path_raises_value_error(self):
Expand Down Expand Up @@ -47,3 +50,40 @@ def test_given_valid_input_succeeds(self):
assert (
_log.read_text().find("ERROR") == -1
), "ERROR found in the log, run was not succesful."

@pytest.mark.skipif(
not any(issues_tests.glob("*")),
reason="Only meant to run locally with issue cases.",
)
@pytest.mark.parametrize(
"ini_file_location",
[
pytest.param(
issues_tests.joinpath("KOSWAT_220", "KOSWAT_analyse_RaLi.ini"),
id="Koswat 220",
)
],
)
def test_given_issue_case(self, ini_file_location: Path):
# 1. Define test data.
assert ini_file_location.is_file()
# Ensure we have a clean results dir.
_log_dir = ini_file_location.parent.joinpath("log_output")
if _log_dir.exists():
shutil.rmtree(_log_dir)
_log_dir.mkdir(parents=True)

_cli_arg = f'--input_file "{ini_file_location}" --log_output "{_log_dir}"'

# 2. Run test.
_run_result = CliRunner().invoke(
__main__.run_analysis,
_cli_arg,
)
# 3. Verify final expectations.
assert _run_result.exit_code == 0
_log: Path = next(_log_dir.glob("*.log"), None)
assert _log and _log.is_file(), "Log file was not generated."
assert (
_log.read_text().find("ERROR") == -1
), "ERROR found in the log, run was not succesful."

0 comments on commit b05ad5e

Please sign in to comment.