feat: add two-stage extractor (#55)
* feat: add two-stage extractor model

* docs: fix links

* chore: update RUL datasets dependency
tilman151 authored Jan 26, 2024
1 parent 9bc8d4b commit 15f0cc8
Showing 6 changed files with 114 additions and 8 deletions.
6 changes: 3 additions & 3 deletions docs/index.md
@@ -17,16 +17,16 @@ Currently, five approaches are implemented, including their original hyperparameters

Three approaches are implemented without their original hyperparameters:

-* **[ConditionalDANN][rul_adapt.approach.conditional]** by Cheng et al. (2021)
-* **[ConditionalMMD][rul_adapt.approach.conditional]** by Cheng et al. (2021)
+* **[ConditionalDANN][rul_adapt.approach.conditional.ConditionalDannApproach]** by Cheng et al. (2021)
+* **[ConditionalMMD][rul_adapt.approach.conditional.ConditionalMmdApproach]** by Cheng et al. (2021)
* **[PseudoLabels][rul_adapt.approach.pseudo_labels]** as used by Wang et al. (2022)

This includes the following general approaches adapted for RUL estimation:

* **Domain Adaption Neural Networks (DANN)** by Ganin et al. (2016)
* **Multi-Kernel Maximum Mean Discrepancy (MMD)** by Long et al. (2015)

-Each approach has an example notebook which can be found in the [examples](examples) folder.
+Each approach has an example notebook which can be found in the [examples](examples/index.md) folder.

## Installation

8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ packages = [{include = "rul_adapt"}]
[tool.poetry.dependencies]
python = "^3.8"
pytorch-lightning = ">1.8.0.post1"
-rul-datasets = ">=0.14.0"
+rul-datasets = ">=0.14.1"
tqdm = "^4.62.2"
hydra-core = "^1.3.1"
pywavelets = "^1.4.1"
1 change: 1 addition & 0 deletions rul_adapt/model/__init__.py
@@ -27,3 +27,4 @@
from .head import FullyConnectedHead
from .rnn import LstmExtractor, GruExtractor
from .wrapper import ActivationDropoutWrapper
+from .two_stage import TwoStageExtractor
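With this re-export in place, the new class is importable directly from the model package (as the test file below does):

from rul_adapt.model import TwoStageExtractor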
59 changes: 59 additions & 0 deletions rul_adapt/model/two_stage.py
@@ -0,0 +1,59 @@
import torch
from torch import nn


class TwoStageExtractor(nn.Module):
    """This module combines two feature extractors into a single network.

    The input data is expected to be of shape `[batch_size, upper_seq_len,
    input_channels, lower_seq_len]`. An example would be vibration data recorded in
    spaced intervals, where `lower_seq_len` is the length of an interval and
    `upper_seq_len` is the window size of a sliding window over the intervals.

    The `lower_stage` is applied to each interval individually to extract features.
    The `upper_stage` is then applied to the extracted features of the window.
    The resulting feature vector should represent the window without the need to
    manually extract features from the raw data of the intervals.
    """

    def __init__(
        self,
        lower_stage: nn.Module,
        upper_stage: nn.Module,
    ):
        """
        Create a new two-stage extractor.

        The lower stage needs to take a tensor of shape `[batch_size,
        input_channels, lower_seq_len]` and return a tensor of shape `[batch_size,
        lower_output_units]`. The upper stage needs to take a tensor of shape
        `[batch_size, lower_output_units, upper_seq_len]` and return a tensor of
        shape `[batch_size, upper_output_units]`.

        Args:
            lower_stage: feature extractor applied to each interval separately
            upper_stage: feature extractor applied to the stacked interval features
        """
        super().__init__()

        self.lower_stage = lower_stage
        self.upper_stage = upper_stage

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply the two-stage extractor to the input tensor.

        The input tensor is expected to be of shape `[batch_size, upper_seq_len,
        input_channels, lower_seq_len]`. The output tensor will be of shape
        `[batch_size, upper_output_units]`.

        Args:
            inputs: the input tensor
        Returns:
            an output tensor of shape `[batch_size, upper_output_units]`
        """
        batch_size, upper_seq_len, input_channels, lower_seq_len = inputs.shape
        # fold the window dim into the batch dim so the lower stage
        # processes each interval separately
        inputs = inputs.reshape(-1, input_channels, lower_seq_len)
        inputs = self.lower_stage(inputs)  # [batch * upper_seq_len, lower_output_units]
        # restore the window dim and move features to the channel position
        # expected by the upper stage
        inputs = inputs.reshape(batch_size, upper_seq_len, -1)
        inputs = torch.transpose(inputs, 1, 2)  # [batch, lower_output_units, upper_seq_len]
        inputs = self.upper_stage(inputs)

        return inputs
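For illustration, a minimal usage sketch; the concrete layer choices mirror the test fixture below and are not prescribed by the module:

import torch
from rul_adapt.model import TwoStageExtractor

# lower stage: maps one interval [*, 3, 64] to an 8-dim feature vector
lower_stage = torch.nn.Sequential(
    torch.nn.Conv1d(3, 8, 3),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 62, 8),
)
# upper stage: maps the stacked features [*, 8, 4] to an 8-dim feature vector
upper_stage = torch.nn.Sequential(
    torch.nn.Conv1d(8, 8, 2),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 3, 8),
)
extractor = TwoStageExtractor(lower_stage, upper_stage)

windows = torch.rand(16, 4, 3, 64)  # [batch, upper_seq_len, input_channels, lower_seq_len]
features = extractor(windows)  # shape: [16, 8]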
46 changes: 46 additions & 0 deletions tests/test_model/test_two_stage.py
@@ -0,0 +1,46 @@
import pytest
import torch

from rul_adapt.model import TwoStageExtractor


@pytest.fixture()
def extractor():
    lower_stage = torch.nn.Sequential(
        torch.nn.Conv1d(3, 8, 3),  # [*, 3, 64] -> [*, 8, 62]
        torch.nn.ReLU(),
        torch.nn.Flatten(),  # -> [*, 8 * 62]
        torch.nn.Linear(8 * 62, 8),  # -> [*, 8]
    )
    upper_stage = torch.nn.Sequential(
        torch.nn.Conv1d(8, 8, 2),  # [*, 8, 4] -> [*, 8, 3]
        torch.nn.ReLU(),
        torch.nn.Flatten(),  # -> [*, 8 * 3]
        torch.nn.Linear(8 * 3, 8),  # -> [*, 8]
    )
    extractor = TwoStageExtractor(lower_stage, upper_stage)

    return extractor


@pytest.fixture()
def inputs():
    return torch.rand(16, 4, 3, 64)


def test_forward_shape(inputs, extractor):
    outputs = extractor(inputs)

    assert outputs.shape == (16, 8)


def test_forward_upper_lower_interaction(inputs, extractor):
    one_sample = inputs[3]

    lower_outputs = extractor.lower_stage(one_sample)
    upper_outputs = extractor.upper_stage(
        torch.transpose(lower_outputs.unsqueeze(0), 1, 2)
    )
    outputs = extractor(inputs)

    assert torch.allclose(upper_outputs, outputs[3])
