I have recently moved from vanilla PyTorch to Lightning, since I very much like how it organizes the code, especially the DataModule. According to the docs, Datamodules are for you if you ever asked the questions (emphasis mine):
what splits did you use?
what transforms did you use?
what normalization did you use?
how did you prepare/tokenize the data?
I am interested in the following scenario. I am given a dataset and I want to perform the following steps:
Split the dataset into train, validation and test.
Train the model and save it (no problem here, I can load any checkpoint).
Come back later (maybe after two days) and test the model on the test set from step (1).
One way to achieve that is to save the indices of the split from step (1). However, I think a more elegant solution is the documented one:
```python
import torch
import lightning as L
from torch.utils.data import random_split
from torchvision.datasets import MNIST


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False)
        self.mnist_predict = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(  # This is what I want to maintain.
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )
```
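As an aside on why the documented approach works at all: `random_split` with a manually seeded generator is deterministic, so re-running `setup()` with the same seed recreates the identical split. A minimal sketch with a stand-in dataset (the names here are illustrative, not from the Lightning docs):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset of 100 items (MNIST itself is not the point here).
dataset = TensorDataset(torch.arange(100).float())

def split(seed: int = 42):
    """Return (train_indices, val_indices) for a seeded 80/20 split."""
    train, val = random_split(
        dataset, [80, 20], generator=torch.Generator().manual_seed(seed)
    )
    return list(train.indices), list(val.indices)

# Same seed -> same split, even in a fresh process days later.
assert split(42) == split(42)
```

This is why the seeded split alone can already stand in for "saving the indices" — as long as the seed and the split sizes never change.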
Basically, my problem boils down to saving the state of the DataModule. I have read the Save DataModule state section of the docs, but I can't understand what is saved under the hood and how I am supposed to load it back again if I want to perform inference.
Please note that I am not interested in the MNIST example per se.