[ENH] xLSTMTime implementation #1709
base: main
Conversation
Codecov Report
Attention: Patch coverage is …

```
@@ Coverage Diff @@
##           main    #1709   +/-  ##
====================================
  Coverage       ?   84.37%
====================================
  Files          ?       40
  Lines          ?     5144
  Branches       ?        0
====================================
  Hits           ?     4340
  Misses         ?      804
  Partials       ?        0
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
hi @fkiraly, I am new to …

Will these tests suffice @fkiraly?
Hi @phoeenniixx,
welcome to pytorch-forecasting, and thank you for your pull request contributing xLSTM.
I added initial comments about the base class you used. Please change it to one of the pytorch-forecasting base classes (see the comment). Since I expect this will change your code a bit, I will wait with a complete review until you have changed it.
`return trend, seasonal`

`class xLSTMTime(nn.Module):`
Please use the base classes of pytorch-forecasting (BaseModelWithCovariates, etc.), depending on the properties of the forecaster.
The advantage of doing this is that the model automatically comes with PyTorch Lightning integration, so less boilerplate is needed.
You might compare it with the NHITS implementation and check how it is implemented.
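For illustration, a minimal sketch of what subclassing such a base class could look like; the method bodies and hyperparameters here are placeholders, not the actual xLSTMTime logic:

```python
from typing import Dict

import torch
from torch import nn

from pytorch_forecasting.metrics import MAE
from pytorch_forecasting.models.base_model import BaseModelWithCovariates


class xLSTMTime(BaseModelWithCovariates):
    def __init__(self, hidden_size: int = 32, loss=None, **kwargs):
        if loss is None:
            loss = MAE()
        # store hyperparameters so that from_dataset and checkpointing work
        self.save_hyperparameters()
        super().__init__(loss=loss, **kwargs)
        # placeholder layer standing in for the xLSTM blocks
        self.network = nn.Linear(1, hidden_size)

    def forward(self, x: Dict[str, torch.Tensor]):
        # x is the dict produced by TimeSeriesDataSet (encoder_cont, decoder_cont, ...)
        prediction = ...  # run decomposition + xLSTM here
        # wrapping the output lets the base class handle loss computation and logging
        return self.to_network_output(prediction=prediction)
```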
Please ensure that the file naming conventions are met, i.e., only lower case is allowed, with `_` as a separator between words: .../x_lstm_time/x_lstm_time.py
```python
        device: Optional[torch.device] = None,
    ):
        """
        Initialize xLSTMTime model.
```
Please check where to put the reference to the paper that originally proposes xLSTM.
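One common spot is a References section in the class docstring; a sketch of the shape, with the citation text itself still to be filled in from the paper:

```python
class xLSTMTime(BaseModelWithCovariates):
    """xLSTMTime model for long-term time series forecasting.

    References
    ----------
    .. [1] <authors>, "xLSTMTime: Long-term Time Series Forecasting with xLSTM",
       <venue / arXiv identifier of the original paper>.
    """
```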
Thanks for the review @benHeid!
Hi @benHeid, I need some help:
Please tell me if I am going in the right direction:

```python
class xLSTMTime(BaseModel):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        xlstm_type: Literal['slstm', 'mlstm'],
        num_layers: int = 1,
        decomposition_kernel: int = 25,
        input_projection_size: Optional[int] = None,
        dropout: float = 0.1,
        loss: Metric = SMAPE(),
        device: Optional[torch.device] = None,
        **kwargs,
    ):
        super().__init__(loss=loss, **kwargs)

        if xlstm_type not in ['slstm', 'mlstm']:
            raise ValueError("xlstm_type must be either 'slstm' or 'mlstm'")

        self.xlstm_type = xlstm_type
        self._device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self._device)

        self.decomposition = SeriesDecomposition(decomposition_kernel)
        self.batch_norm = nn.BatchNorm1d(hidden_size)

        self.input_projection_size = input_projection_size or hidden_size
        self.input_linear = None

        if xlstm_type == 'mlstm':
            self.lstm = mLSTMNetwork(
                input_size=hidden_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                output_size=hidden_size,
                dropout=dropout,
                device=self._device,
            )
        else:  # slstm
            self.lstm = sLSTMNetwork(
                input_size=hidden_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                output_size=hidden_size,
                dropout=dropout,
                device=self._device,
            )

        self.output_linear = nn.Linear(hidden_size, output_size)
        self.instance_norm = nn.InstanceNorm1d(output_size)

    def forward(
        self,
        x: Dict[str, torch.Tensor],
        hidden_states: Optional[
            Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
        ] = None,
    ) -> Tuple[
        torch.Tensor,
        Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        encoder_cont = x["encoder_cont"]
        batch_size, seq_len, n_features = encoder_cont.shape

        trend, seasonal = self.decomposition(encoder_cont)
        x = torch.cat([trend, seasonal], dim=-1)
        concatenated_features = x.shape[-1]

        # lazily create the input projection once the concatenated feature size is known
        if self.input_linear is None:
            self.input_linear = nn.Linear(concatenated_features, self.input_projection_size).to(self._device)

        x = self.input_linear(x)
        x = x.transpose(1, 2)
        x = self.batch_norm(x)
        x = x.transpose(1, 2)

        if hidden_states is None:
            hidden_states = self.lstm.init_hidden(batch_size)

        x = x.transpose(0, 1)
        output, hidden_states = self.lstm(x, *hidden_states)
        if isinstance(output, tuple):
            output = output[0]
        if output.dim() == 2:
            output = output.unsqueeze(0)

        output = self.output_linear(output)
        output = output.transpose(1, 2)
        output = self.instance_norm(output)
        output = output.transpose(1, 2)

        return output, hidden_states

    def predict(
        self,
        x: Dict[str, torch.Tensor],
        hidden_states: Optional[
            Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
        ] = None,
    ) -> torch.Tensor:
        output, _ = self.forward(x, hidden_states)
        return output

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y[0] if isinstance(y, tuple) else y
        y_pred, _ = self(x)
        if y_pred.ndim == 3 and y_pred.size(0) == 1:
            y_pred = y_pred.squeeze(0)
        loss = self.loss(y_pred, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y[0] if isinstance(y, tuple) else y
        y_pred, _ = self(x)
        if y_pred.ndim == 3 and y_pred.size(0) == 1:
            y_pred = y_pred.squeeze(0)
        loss = self.loss(y_pred, y)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y = y[0] if isinstance(y, tuple) else y
        y_pred, _ = self(x)
        if y_pred.ndim == 3 and y_pred.size(0) == 1:
            y_pred = y_pred.squeeze(0)
        loss = self.loss(y_pred, y)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=10)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
```
Also, do we need to change the base class of just …
Mhm, if the implementation does not support any exogenous features, then either …

I agree that a dict should be used here.

Yes, that is the target time series.

You might check the RNN implementation, since it also inherits from an autoregressive base model and is probably the most similar of the implemented models. But I think you are going in the right direction.
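For orientation, the rough skeleton that RNN-style models in pytorch-forecasting follow looks like this; the method bodies are placeholders, and RecurrentNetwork itself uses the covariates variant of this base class:

```python
from pytorch_forecasting.models.base_model import AutoRegressiveBaseModel


class xLSTMTime(AutoRegressiveBaseModel):
    @classmethod
    def from_dataset(cls, dataset, **kwargs):
        # derive network sizes and sensible defaults from the TimeSeriesDataSet
        return super().from_dataset(dataset, **kwargs)

    def forward(self, x):
        # encode the observed history, then decode step by step over the
        # prediction horizon, feeding each prediction back in as input
        ...
```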
Hi @benHeid, I have updated the implementation using the suggested base class. I can add the docstrings in subsequent commits once I am sure that this is what we want.
Sorry for my late response. Please ensure that the linting tests are green; running the pre-commit hooks locally should probably take care of it. Regarding the failing tests, you might check what the output currently looks like by manually executing the xLSTM; you might then see what the issue is. @fkiraly, do we have any guides for pytorch-forecasting on how to write tests?
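To illustrate the suggestion about manually executing the model, a quick shape check could look like this (assuming the xLSTMTime class from the snippet earlier in this thread is importable; the sizes are arbitrary):

```python
import torch

# dummy batch in the dict format produced by TimeSeriesDataSet
batch = {"encoder_cont": torch.randn(4, 30, 3)}  # (batch, seq_len, features)

model = xLSTMTime(input_size=3, hidden_size=16, output_size=1, xlstm_type="slstm")
output, hidden_states = model(batch)

# compare this against the shape the loss and the test utilities expect
print(output.shape)
```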
Thanks for the reply @benHeid. Actually, the reason the tests are failing is that earlier I was using tensors, tuples, etc., and now TimeSeriesDataSet is being used, which uses a dict. I can correct those tests, but I didn't do that yet because I noticed that other models just use functions like `_integration`:

```python
def _integration(
    data_with_covariates, tmp_path, cell_type="LSTM", data_loader_kwargs={}, clip_target: bool = False, **kwargs
):
    data_with_covariates = data_with_covariates.copy()
    if clip_target:
        data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0)
    else:
        data_with_covariates["target"] = data_with_covariates["volume"]
    data_loader_default_kwargs = dict(
        target="target",
        time_varying_known_reals=["price_actual"],
        time_varying_unknown_reals=["target"],
        static_categoricals=["agency"],
        add_relative_time_idx=True,
    )
    data_loader_default_kwargs.update(data_loader_kwargs)
    dataloaders_with_covariates = make_dataloaders(data_with_covariates, **data_loader_default_kwargs)
    train_dataloader = dataloaders_with_covariates["train"]
    val_dataloader = dataloaders_with_covariates["val"]
    test_dataloader = dataloaders_with_covariates["test"]
    early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
    logger = TensorBoardLogger(tmp_path)
    trainer = pl.Trainer(
        max_epochs=3,
        gradient_clip_val=0.1,
        callbacks=[early_stop_callback],
        enable_checkpointing=True,
        default_root_dir=tmp_path,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        logger=logger,
    )
    net = RecurrentNetwork.from_dataset(
        train_dataloader.dataset,
        cell_type=cell_type,
        learning_rate=0.15,
        log_gradient_flow=True,
        log_interval=1000,
        hidden_size=5,
        **kwargs,
    )
    net.size()
    try:
        trainer.fit(
            net,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )
        test_outputs = trainer.test(net, dataloaders=test_dataloader)
        assert len(test_outputs) > 0
        # check loading
        net = RecurrentNetwork.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
        # check prediction
        net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
    finally:
        shutil.rmtree(tmp_path, ignore_errors=True)

    net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
```

Here they are using keys like "volume", and this is for …
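A test for xLSTMTime would not have to reuse those dataset-specific fixtures; a self-contained sketch with synthetic data could look roughly like this (`xLSTMTime.from_dataset` with a `hidden_size` argument and the import path are assumptions based on the discussion above):

```python
import numpy as np
import pandas as pd
import lightning.pytorch as pl  # older setups: import pytorch_lightning as pl

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.models.x_lstm_time.x_lstm_time import xLSTMTime  # hypothetical path


def test_xlstmtime_smoke(tmp_path):
    # two short synthetic series instead of the "volume"/"agency" fixture data
    data = pd.DataFrame(
        {
            "time_idx": np.tile(np.arange(100), 2),
            "series": np.repeat(["a", "b"], 100),
            "value": np.random.normal(size=200).astype("float32"),
        }
    )
    dataset = TimeSeriesDataSet(
        data,
        time_idx="time_idx",
        target="value",
        group_ids=["series"],
        max_encoder_length=24,
        max_prediction_length=6,
        time_varying_unknown_reals=["value"],
    )
    train_dataloader = dataset.to_dataloader(train=True, batch_size=8)

    net = xLSTMTime.from_dataset(dataset, hidden_size=8)  # assumed signature
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, default_root_dir=tmp_path)
    trainer.fit(net, train_dataloaders=train_dataloader)
```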
For now, I am just removing the test file and updating the code as required.
Minor things before a more thorough review:

- can you kindly add tests for some basic use cases?
- can you make sure nothing except imports is in the `__init__` files, similar to the recent changes in the repo? (A sketch follows this list.)
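For instance, the package `__init__` would then contain nothing but re-exports; the `pytorch_forecasting/models/` prefix below is an assumption extrapolated from the naming comment above:

```python
# pytorch_forecasting/models/x_lstm_time/__init__.py -- imports only, no logic
from pytorch_forecasting.models.x_lstm_time.x_lstm_time import xLSTMTime

__all__ = ["xLSTMTime"]
```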
Thanks for the review @fkiraly!
One more thing: I haven't added the docs for now, as I wasn't sure which …
Description

This PR tries to implement xLSTMTime based on this paper; see also sktime issue #6793.

Checklist

- pre-commit hooks are installed with `pre-commit install`. To run hooks independent of commit, execute `pre-commit run --all-files`.