From 5094a8287679d7f68b917444a084d44686391f8d Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Fri, 26 Jan 2024 16:58:55 +0100 Subject: [PATCH] fix: make two-stage extractor work with empty batches (#56) --- rul_adapt/model/two_stage.py | 3 ++- tests/test_model/test_two_stage.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/rul_adapt/model/two_stage.py b/rul_adapt/model/two_stage.py index 949632c..6dfd2d4 100644 --- a/rul_adapt/model/two_stage.py +++ b/rul_adapt/model/two_stage.py @@ -52,7 +52,8 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: batch_size, upper_seq_len, input_channels, lower_seq_len = inputs.shape inputs = inputs.reshape(-1, input_channels, lower_seq_len) inputs = self.lower_stage(inputs) - inputs = inputs.reshape(batch_size, upper_seq_len, -1) + _, lower_output_units = inputs.shape + inputs = inputs.reshape(batch_size, upper_seq_len, lower_output_units) inputs = torch.transpose(inputs, 1, 2) inputs = self.upper_stage(inputs) diff --git a/tests/test_model/test_two_stage.py b/tests/test_model/test_two_stage.py index 186b351..03a0d0b 100644 --- a/tests/test_model/test_two_stage.py +++ b/tests/test_model/test_two_stage.py @@ -44,3 +44,7 @@ def test_forward_upper_lower_interaction(inputs, extractor): outputs = extractor(inputs) assert torch.allclose(upper_outputs, outputs[3]) + + +def test_forward_with_empty_input(extractor): + output = extractor(torch.empty(0, 4, 3, 64))