Skip to content

Commit

Permalink
feat: add test step to supervised approach (#47)
Browse files Browse the repository at this point in the history
* feat: add test step for adaption to supervised

* feat: make test and val possible with and without adaption data module

* fix: linting issues
  • Loading branch information
tilman151 authored Aug 31, 2023
1 parent 8077381 commit f5ca137
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 9 deletions.
65 changes: 58 additions & 7 deletions rul_adapt/approach/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
```
"""

from typing import Tuple, Literal, Any, Dict
from typing import Literal, Any, Dict, List

import torch
import torchmetrics

from rul_adapt import utils
from rul_adapt.approach.abstract import AdaptionApproach
from rul_adapt.approach.evaluation import filter_batch
from rul_adapt.approach.evaluation import filter_batch, AdaptionEvaluator


class SupervisedApproach(AdaptionApproach):
Expand All @@ -39,6 +39,7 @@ def __init__(
self,
loss_type: Literal["mse", "mae", "rmse"],
rul_scale: int = 1,
rul_score_mode: Literal["phm08", "phm12"] = "phm08",
evaluate_degraded_only: bool = False,
**optim_kwargs: Any,
) -> None:
Expand All @@ -52,6 +53,7 @@ def __init__(
see [here][rul_adapt.utils.OptimizerFactory].
Args:
rul_score_mode:
loss_type: Training loss function to use. Either 'mse', 'mae' or 'rmse'.
rul_scale: Scalar to multiply the RUL prediction with.
evaluate_degraded_only: Whether to only evaluate the RUL score on degraded
Expand All @@ -62,12 +64,17 @@ def __init__(

self.loss_type = loss_type
self.rul_scale = rul_scale
self.rul_score_mode = rul_score_mode
self.evaluate_degraded_only = evaluate_degraded_only
self.optim_kwargs = optim_kwargs

self.train_loss = utils.get_loss(loss_type)
self._get_optimizer = utils.OptimizerFactory(**self.optim_kwargs)
self.val_loss = torchmetrics.MeanSquaredError(squared=False)
self.test_loss = torchmetrics.MeanSquaredError(squared=False)
self.evaluator = AdaptionEvaluator(
self.forward, self.log, self.rul_score_mode, self.evaluate_degraded_only
)

self.save_hyperparameters()

Expand All @@ -77,9 +84,7 @@ def configure_optimizers(self) -> Dict[str, Any]:
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.regressor(self.feature_extractor(inputs)) * self.rul_scale

def training_step(
self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> torch.Tensor:
def training_step(self, batch: List[torch.Tensor], batch_idx: int) -> torch.Tensor:
"""
Execute one training step.
Expand All @@ -101,7 +106,10 @@ def training_step(
return loss

def validation_step(
self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
self,
batch: List[torch.Tensor],
batch_idx: int,
dataloader_idx: int = -1,
) -> None:
"""
Execute one validation step.
Expand All @@ -114,7 +122,50 @@ def validation_step(
batch: A list of feature and label tensors.
batch_idx: The index of the current batch.
"""
inputs, labels = filter_batch(*batch, degraded_only=self.evaluate_degraded_only)
if dataloader_idx == -1:
self._no_adapt_validation_step(batch)
else:
domain = utils.dataloader2domain(dataloader_idx)
self.evaluator.validation(batch, domain)

def _no_adapt_validation_step(self, batch: List[torch.Tensor]) -> None:
inputs, labels = batch
inputs, labels = filter_batch(
inputs, labels, degraded_only=self.evaluate_degraded_only
)
predictions = self.forward(inputs)
self.val_loss(predictions, labels[:, None])
self.log("val/loss", self.val_loss)

def test_step(
self, batch: List[torch.Tensor], batch_idx: int, dataloader_idx: int = -1
) -> None:
"""
Execute one test step.
The `batch` argument is a list of two tensors representing features and
labels. A RUL prediction is made from the features and the validation RMSE
and RUL score are calculated. The metrics recorded for dataloader idx zero
are assumed to be from the source domain and for dataloader idx one from the
target domain. The metrics are written to the configured logger under the
prefix `test`.
Args:
batch: A list containing a feature and a label tensor.
batch_idx: The index of the current batch.
dataloader_idx: The index of the current dataloader (0: source, 1: target).
"""
if dataloader_idx == -1:
self._no_adapt_test_step(batch)
else:
domain = utils.dataloader2domain(dataloader_idx)
self.evaluator.test(batch, domain)

def _no_adapt_test_step(self, batch: List[torch.Tensor]) -> None:
inputs, labels = batch
inputs, labels = filter_batch(
inputs, labels, degraded_only=self.evaluate_degraded_only
)
predictions = self.forward(inputs)
self.test_loss(predictions, labels[:, None])
self.log("test/loss", self.test_loss)
59 changes: 57 additions & 2 deletions tests/test_approach/test_supervised.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
from unittest import mock

import pytest
import pytorch_lightning as pl
import rul_datasets.reader
import torch
import torchmetrics
from torchmetrics import Metric

from rul_adapt import model
from rul_adapt.approach import SupervisedApproach
from tests.test_approach import utils


@pytest.fixture()
def inputs():
return torch.randn(10, 14, 20), torch.arange(10, dtype=torch.float)
return [torch.randn(10, 1, 20), torch.arange(10, dtype=torch.float)]


@pytest.fixture()
def models():
fe = model.LstmExtractor(14, [32, 32, 32], bidirectional=True)
fe = model.LstmExtractor(1, [32, 32, 32], bidirectional=True)
reg = model.FullyConnectedHead(64, [32, 1], act_func_on_last_layer=False)

return fe, reg
Expand All @@ -30,6 +33,25 @@ def approach(models):
return approach


@pytest.fixture()
def dummy_dm():
source = rul_datasets.RulDataModule(rul_datasets.reader.DummyReader(1), 32)
target = rul_datasets.RulDataModule(rul_datasets.reader.DummyReader(2), 32)
dm = rul_datasets.DomainAdaptionDataModule(source, target)

return dm


@pytest.fixture()
def silent_trainer():
return pl.Trainer(
logger=False,
enable_progress_bar=False,
enable_model_summary=False,
enable_checkpointing=False,
)


class TestSupervisedApproach:
@pytest.mark.parametrize(
["loss_type", "exp_loss", "squared"],
Expand Down Expand Up @@ -100,3 +122,36 @@ def test_val_step_logging(self, inputs, approach):

approach.val_loss.assert_called_once()
approach.log.assert_called_with("val/loss", approach.val_loss)

@torch.no_grad()
def test_test_step(self, approach, inputs):
approach.test_step(inputs, batch_idx=0, dataloader_idx=0)
approach.test_step(inputs, batch_idx=0, dataloader_idx=1)
with pytest.raises(RuntimeError):
approach.test_step(inputs, batch_idx=0, dataloader_idx=2)

@torch.no_grad()
def test_test_step_logging(self, approach, mocker):
utils.check_test_logging(approach, mocker)

def test_validation_switcher(self, approach, dummy_dm, silent_trainer, mocker):
mock_no_adapt_val = mocker.patch.object(approach, "_no_adapt_validation_step")
mock_adapt_val = mocker.patch.object(approach.evaluator, "validation")

silent_trainer.validate(approach, dummy_dm.source)
mock_no_adapt_val.assert_called()
mock_adapt_val.assert_not_called()

silent_trainer.validate(approach, dummy_dm)
mock_adapt_val.assert_called()

def test_test_switcher(self, approach, dummy_dm, silent_trainer, mocker):
mock_no_adapt_test = mocker.patch.object(approach, "_no_adapt_test_step")
mock_adapt_test = mocker.patch.object(approach.evaluator, "test")

silent_trainer.test(approach, dummy_dm.source)
mock_no_adapt_test.assert_called()
mock_adapt_test.assert_not_called()

silent_trainer.test(approach, dummy_dm)
mock_adapt_test.assert_called()

0 comments on commit f5ca137

Please sign in to comment.