[ENH] xLSTMTime implementation #1709

Open · wants to merge 19 commits into base: main

Conversation

phoeenniixx
Copy link

@phoeenniixx phoeenniixx commented Nov 9, 2024

Description

This PR tries to implement xLSTMTime based on this paper

see also sktime issue #6793

Checklist

  • Linked issues (if existing)
  • Amended changelog for large changes (and added myself there as contributor)
  • Added/modified tests
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independent of commit, execute pre-commit run --all-files

@phoeenniixx changed the title from "initial commit" to "[ENH] xLSTMTime implementation" on Nov 9, 2024

codecov bot commented Nov 9, 2024

Codecov Report

Attention: Patch coverage is 0% with 315 lines in your changes missing coverage. Please review.

Please upload report for BASE (main@a884c4d). Learn more about missing BASE report.

Files with missing lines Patch % Lines
pytorch_forecasting/models/xLSTMTime/mLSTM/cell.py 0.00% 74 Missing ⚠️
pytorch_forecasting/models/xLSTMTime/sLSTM/cell.py 0.00% 64 Missing ⚠️
pytorch_forecasting/models/xLSTMTime/xLSTMTime.py 0.00% 57 Missing ⚠️
...ytorch_forecasting/models/xLSTMTime/mLSTM/layer.py 0.00% 43 Missing ⚠️
...ytorch_forecasting/models/xLSTMTime/sLSTM/layer.py 0.00% 42 Missing ⚠️
...orch_forecasting/models/xLSTMTime/sLSTM/network.py 0.00% 20 Missing ⚠️
...orch_forecasting/models/xLSTMTime/mLSTM/network.py 0.00% 15 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1709   +/-   ##
=======================================
  Coverage        ?   84.37%           
=======================================
  Files           ?       40           
  Lines           ?     5144           
  Branches        ?        0           
=======================================
  Hits            ?     4340           
  Misses          ?      804           
  Partials        ?        0           
Flag Coverage Δ
cpu 84.37% <0.00%> (?)
pytest 84.37% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.


@phoeenniixx
Author

Hi @fkiraly, I am new to pytorch-forecasting and its tests, so can you please tell me exactly what I am "missing"?

@phoeenniixx
Author

Will these tests suffice, @fkiraly?

Collaborator

@benHeid left a comment

Hi @phoeenniixx,
welcome to pytorch-forecasting, and thank you for your pull request contributing xLSTM.
I added initial comments about the base class you used. Please change it to one of the pytorch-forecasting base classes (see the comment). Since I suppose this will change your code a bit, I will wait with a complete review until you have changed it.

return trend, seasonal


class xLSTMTime(nn.Module):
Collaborator

Please use the base classes of pytorch-forecasting (BaseModelWithCovariates, etc.), depending on the properties of the forecaster.
The advantage of doing this is that the model automatically comes with PyTorch Lightning integration, so less boilerplate is needed.

You might compare it with the NHiTS implementation and check how it is done there.
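
As a minimal sketch of the suggested pattern (using BaseModel for brevity; BaseModelWithCovariates follows the same scheme), with placeholder names rather than the actual xLSTMTime code:

from pytorch_forecasting.metrics import MAE
from pytorch_forecasting.models.base_model import BaseModel


class SketchForecaster(BaseModel):
    """Illustrative skeleton of a pytorch-forecasting model subclass."""

    def __init__(self, hidden_size: int = 16, loss=None, **kwargs):
        # The base class wires up PyTorch Lightning, logging and the loss,
        # so no training_step/configure_optimizers boilerplate is needed.
        self.save_hyperparameters()
        super().__init__(loss=loss or MAE(), **kwargs)
        # ... build the actual network modules here ...

    def forward(self, x: dict):
        # x is the dictionary yielded by TimeSeriesDataSet dataloaders;
        # a real model returns self.to_network_output(prediction=...).
        raise NotImplementedError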

Collaborator

Please ensure that the file naming conventions are met, i.e., only lower case is allowed, with _ as the separator between words: .../x_lstm_time/x_lstm_time.py

device: Optional[torch.device] = None,
):
"""
Initialize xLSTMTime model.
Collaborator

Please check where to put the reference to the paper that originally proposes xLSTM.

@phoeenniixx
Author

Thanks for the review @benHeid!
I will have to restructure a little, I guess. I will pick an appropriate base class and use it in the main xLSTMTime class; the rest will be left untouched (with respect to the base class, at least)?
I will make the changes and get back to you in a few days!
Thanks!

@phoeenniixx
Author

Hi @benHeid, I need some help:

  • here I implemented the xLSTMTime class using BaseModel, as for now I think this is the best-fitting class... what do you think?

  • Also, I made some changes to the forward function: before, it accepted a Tensor object, and I changed it to a Dict, as I found out that users mainly work with TimeSeriesDataSet, which returns a dict. Please correct me if I am wrong here.

  • I am using the encoder_cont key of the dict as input x.

Please tell me if I am going in the right direction.

# imports (torch, typing, pytorch_forecasting's BaseModel and SMAPE, and the PR's
# SeriesDecomposition / mLSTMNetwork / sLSTMNetwork modules) are omitted in this snippet
class xLSTMTime(BaseModel):

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        xlstm_type: Literal['slstm', 'mlstm'],
        num_layers: int = 1,
        decomposition_kernel: int = 25,
        input_projection_size: Optional[int] = None,
        dropout: float = 0.1,
        loss: Metric = SMAPE(),
        device: Optional[torch.device] = None,
        **kwargs
    ):
        super().__init__(loss=loss, **kwargs)

        if xlstm_type not in ['slstm', 'mlstm']:
            raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'")

        self.xlstm_type = xlstm_type
        self._device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self._device)

        # decompose the input series into trend and seasonal components
        self.decomposition = SeriesDecomposition(decomposition_kernel)
        self.batch_norm = nn.BatchNorm1d(hidden_size)

        self.input_projection_size = input_projection_size or hidden_size

        # created lazily in forward() once the concatenated feature size is known
        self.input_linear = None

        if xlstm_type == 'mlstm':
            self.lstm = mLSTMNetwork(
                input_size=hidden_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                output_size=hidden_size,
                dropout=dropout,
                device=self.device
            )
        else:  # slstm
            self.lstm = sLSTMNetwork(
                input_size=hidden_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                output_size=hidden_size,
                dropout=dropout,
                device=self.device
            )

        self.output_linear = nn.Linear(hidden_size, output_size)
        self.instance_norm = nn.InstanceNorm1d(output_size)

    def forward(
        self,
        x: Dict[str, torch.Tensor],  
        hidden_states: Optional[
            Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
        ] = None
    ) -> Tuple[torch.Tensor, Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
   
        # TimeSeriesDataSet dataloaders yield a dict; the continuous encoder variables are the model input
        encoder_cont = x["encoder_cont"]
        batch_size, seq_len, n_features = encoder_cont.shape

        trend, seasonal = self.decomposition(encoder_cont)

        x = torch.cat([trend, seasonal], dim=-1)
        concatenated_features = x.shape[-1]

        if self.input_linear is None:
            self.input_linear = nn.Linear(concatenated_features, self.input_projection_size).to(self._device)

        x = self.input_linear(x)

        x = x.transpose(1, 2)  
        x = self.batch_norm(x)
        x = x.transpose(1, 2)  

        if hidden_states is None:
            hidden_states = self.lstm.init_hidden(batch_size)

        # (batch, seq, hidden) -> (seq, batch, hidden) before the s/mLSTM network
        x = x.transpose(0, 1)
        output, hidden_states = self.lstm(x, *hidden_states)

        if isinstance(output, tuple):
            output = output[0]

        if output.dim() == 2:
            output = output.unsqueeze(0)
        output = self.output_linear(output)

        output = output.transpose(1, 2)
        output = self.instance_norm(output)
        output = output.transpose(1, 2)

        return output, hidden_states


    def predict(
            self,
            x: torch.Tensor,
            hidden_states: Optional[
                Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
            ] = None
    ) -> torch.Tensor:

        output, _ = self.forward(x, hidden_states)
        return output

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y[0] if isinstance(y, tuple) else y 

        y_pred, _ = self(x)

        if y_pred.ndim == 3 and y_pred.size(0) == 1:
            y_pred = y_pred.squeeze(0)  
        loss = self.loss(y_pred, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y[0] if isinstance(y, tuple) else y 

        y_pred, _ = self(x)

        if y_pred.ndim == 3 and y_pred.size(0) == 1:
            y_pred = y_pred.squeeze(0)  
        loss = self.loss(y_pred, y)
        self.log("val_loss", loss)

        return loss




    def test_step(self, batch, batch_idx):
        x, y = batch
        y = y[0] if isinstance(y, tuple) else y 

        y_pred, _ = self(x)

        if y_pred.ndim == 3 and y_pred.size(0) == 1:
            y_pred = y_pred.squeeze(0)  
        loss = self.loss(y_pred, y)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=10)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

@phoeenniixx
Author

Also, do we need to change the base class of xLSTMTime only, or should mLSTMNetwork and sLSTMNetwork also be changed?
(Although I think they are just parts of this main class, so they could inherit from nn.Module without any problem?)

@benHeid
Collaborator

benHeid commented Dec 10, 2024

  • here I implemented xLSTMTime class using BaseModel as for now I think this is the best fitted class... what do you think?

Mhm, if the implementation does not support any exogenous features, then either BaseModel or AutoRegressiveBaseModel would work. I would assume that the latter is probably the better fit.

  • Also, I made some changes in the forward function of the code where before it was accepting Tensor object, I changed it to Dict as I found out that the user mainly uses TimeSeriesDataSet and it returns a dict, please correct me if I am wrong here.

I agree that a dict should be used here.

  • I am using the encoder_cont key of the dict as input x.

Yes, that is the target time series (see the sketch below).

Please tell me if I am in a right direction

You might check the RNN implementation, since it also inherits from an autoregressive base model and is probably the most similar of the implemented models.
I would suggest that you check carefully whether you really need to implement the step / training_step methods etc., or whether it is sufficient to use the inherited methods from the base class.

But I think you are going in the right direction.
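
As a concrete reference for the dict-based input discussed above, a small self-contained sketch (synthetic data, arbitrary column names and sizes) showing what a TimeSeriesDataSet dataloader yields:

import numpy as np
import pandas as pd

from pytorch_forecasting import TimeSeriesDataSet

# two short synthetic series, only to build an example dataset
df = pd.DataFrame(
    {
        "time_idx": np.tile(np.arange(100), 2),
        "series": np.repeat(["a", "b"], 100),
        "value": np.random.randn(200).astype("float32"),
    }
)
dataset = TimeSeriesDataSet(
    df,
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    max_encoder_length=30,
    max_prediction_length=10,
    time_varying_unknown_reals=["value"],
)
dataloader = dataset.to_dataloader(train=True, batch_size=4)

x, y = next(iter(dataloader))
print(list(x.keys()))           # includes "encoder_cont", "decoder_cont", "encoder_lengths", ...
print(x["encoder_cont"].shape)  # (batch_size, encoder_length, n_features)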

@phoeenniixx
Author

phoeenniixx commented Dec 12, 2024

Hi @benHeid, I have updated the implementation using AutoRegressiveBaseModel; please review it. Also, I have not changed or added the tests (they are failing due to some changes in the input and output format), as I saw that there is a specific "trend" to how tests are written for other modules, and I might need some help with that. Can you please give me a brief overview of them, like which specific tests I should add, etc.?

I can add the docstrings in subsequent commits once I am sure that this is what we want.

@benHeid
Collaborator

benHeid commented Dec 24, 2024

Sorry for my late response. Please ensure that the linting tests are green; running the pre-commit hooks locally should probably take care of it.

Regarding the failing tests, you might check what the output currently looks like by manually executing the xLSTM. You might then see what the issue is.

@fkiraly do we have any guides for pytorch-forecasting on how to write tests?

@phoeenniixx
Author

Thanks for the reply @benHeid! Actually, the reason the tests are failing is that earlier I was using tensors, tuples, etc., and now TimeSeriesDataSet is being used, which uses a dict. I could correct those tests, but I didn't, because I noticed that other models just use functions like test_integration etc. To write those functions, I first need to understand the inputs to them, like the dataloaders and the dataset: which data are we using here, what are the labels, etc.? Is that data arbitrary, or some pre-defined dataset?
For example, look at this function from test_models.test_rnn_model.py:

def _integration(
    data_with_covariates, tmp_path, cell_type="LSTM", data_loader_kwargs={}, clip_target: bool = False, **kwargs
):
    data_with_covariates = data_with_covariates.copy()
    if clip_target:
        data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0)
    else:
        data_with_covariates["target"] = data_with_covariates["volume"]
    data_loader_default_kwargs = dict(
        target="target",
        time_varying_known_reals=["price_actual"],
        time_varying_unknown_reals=["target"],
        static_categoricals=["agency"],
        add_relative_time_idx=True,
    )
    data_loader_default_kwargs.update(data_loader_kwargs)
    dataloaders_with_covariates = make_dataloaders(data_with_covariates, **data_loader_default_kwargs)
    train_dataloader = dataloaders_with_covariates["train"]
    val_dataloader = dataloaders_with_covariates["val"]
    test_dataloader = dataloaders_with_covariates["test"]

    early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")

    logger = TensorBoardLogger(tmp_path)
    trainer = pl.Trainer(
        max_epochs=3,
        gradient_clip_val=0.1,
        callbacks=[early_stop_callback],
        enable_checkpointing=True,
        default_root_dir=tmp_path,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        logger=logger,
    )

    net = RecurrentNetwork.from_dataset(
        train_dataloader.dataset,
        cell_type=cell_type,
        learning_rate=0.15,
        log_gradient_flow=True,
        log_interval=1000,
        hidden_size=5,
        **kwargs,
    )
    net.size()
    try:
        trainer.fit(
            net,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )
        test_outputs = trainer.test(net, dataloaders=test_dataloader)
        assert len(test_outputs) > 0
        # check loading
        net = RecurrentNetwork.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

        # check prediction
        net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
    finally:
        shutil.rmtree(tmp_path, ignore_errors=True)

    net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)

Here they are using keys like "volume", and this is for data_with_covariates, but I am not using the covariate base class, so I cannot directly take this code and modify it to my requirements. I want to understand how this whole thing works, and then I can write the test...
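
For reference, and worth double-checking: the data_with_covariates fixture in the test conftest appears to be built from the bundled Stallion demand example dataset, which is where column names like "volume", "agency" and "price_actual" come from. It can be inspected directly:

from pytorch_forecasting.data.examples import get_stallion_data

# load the example dataset the test fixtures appear to be based on
data = get_stallion_data()
print(data.columns.tolist())
print(data.head())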

@phoeenniixx
Author

For now, I am just removing the test file and updating the code as required.

Collaborator

@fkiraly left a comment

Minor things before a more thorough review:

  • can you kindly add tests for some basic use cases?
  • can you make sure nothing except imports is in the __init__ files, similar to the recent changes in the repo? (example below)
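
For example (module path illustrative, following the lower-case file naming requested above), an __init__.py that contains imports only could look like:

# pytorch_forecasting/models/x_lstm_time/__init__.py  (illustrative path)
"""xLSTMTime model package: re-exports only, no logic."""

from pytorch_forecasting.models.x_lstm_time.x_lstm_time import xLSTMTime

__all__ = ["xLSTMTime"]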

@fkiraly added the "enhancement" (New feature or request) and "new network" labels on Jan 5, 2025
@phoeenniixx
Author

phoeenniixx commented Jan 6, 2025

Thanks for the review @fkiraly!
Some questions:

  • to create the tests, I think dataloaders_fixed_window_without_covariates (from tests.test_models.conftest.py) should work?
  • Should I use the tests already in the repo (like test_rnn_model) as a reference and modify them for xLSTM? (rough sketch below)
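
A rough sketch of what such a test could look like, assuming the fixture returns a dict with "train"/"val" dataloaders (as the conftest suggests) and that xLSTMTime supports the from_dataset constructor inherited from the base class; the module path follows this PR, the arguments are placeholders:

import lightning.pytorch as pl

from pytorch_forecasting.models.xLSTMTime.xLSTMTime import xLSTMTime  # path as in this PR


def test_integration(dataloaders_fixed_window_without_covariates, tmp_path):
    train_dataloader = dataloaders_fixed_window_without_covariates["train"]
    val_dataloader = dataloaders_fixed_window_without_covariates["val"]

    # assumes a working from_dataset inherited from the pytorch-forecasting base class
    net = xLSTMTime.from_dataset(
        train_dataloader.dataset,
        hidden_size=8,
        xlstm_type="slstm",
    )
    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        default_root_dir=tmp_path,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
    # smoke-test prediction on the validation data
    net.predict(val_dataloader, fast_dev_run=True)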

@phoeenniixx
Author

One more thing: I haven't added the docs for now, as I wasn't sure which base model would be the best fit, and I will add those once it is clear.
For now I will add the docs for the m_lstm and s_lstm.

@phoeenniixx requested a review from fkiraly on January 22, 2025, 01:15
Labels: enhancement (New feature or request), new network
3 participants