diff --git a/deeplay/models/recurrentmodel.py b/deeplay/models/recurrentmodel.py index 9aaa6b0..07ef705 100644 --- a/deeplay/models/recurrentmodel.py +++ b/deeplay/models/recurrentmodel.py @@ -167,7 +167,10 @@ def forward(self, x): x = self.embedding_dropout(x) outputs = x - outputs, hidden = super().forward(outputs) + if self.return_cell_state: + outputs, hidden = super().forward(outputs) + else: + outputs = super().forward(outputs) if self.bidirectional: outputs = (