From 260ff260ed16bf3f6df7744907595729cee8c6b1 Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Thu, 23 Jan 2025 15:55:47 +0100
Subject: [PATCH] Bug fix PR #423

Fixes for the issues raised in PR #423:
- val_dataloader built its DummyDataloader from train_dataset instead
  of val_dataset and never returned the dataloader.
- transfer_batch_to_device did not return the transferred batch.
- The transfer helpers become private (_transfer_batch_to_device,
  _transfer_batch_to_device_dummy) and __init__ now binds
  self.transfer_batch_to_device explicitly, so the dataloaders can swap
  in the no-op hook once the data has been moved to the device.
- Minor indentation and blank-line fixes.
---
 pina/data/data_module.py | 42 ++++++++++++++++++++++++++--------------
 1 file changed, 28 insertions(+), 14 deletions(-)

diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index 4b529fe2..ea3cbb46 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -8,6 +8,7 @@
 from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
 
+
 class DummyDataloader:
     def __init__(self, dataset, device):
         self.dataset = dataset.get_all_data()
@@ -21,6 +22,7 @@ def __len__(self):
     def __next__(self):
         return self.dataset
 
+
 class Collator:
     def __init__(self, max_conditions_lengths, ):
         self.max_conditions_lengths = max_conditions_lengths
@@ -48,7 +50,7 @@ def _collate_standard_dataloader(self, batch):
             for arg in condition_args:
                 data_list = [batch[idx][condition_name][arg] for idx in range(
                     min(len(batch),
-                    self.max_conditions_lengths[condition_name]))]
+                        self.max_conditions_lengths[condition_name]))]
                 if isinstance(data_list[0], LabelTensor):
                     single_cond_dict[arg] = LabelTensor.stack(data_list)
                 elif isinstance(data_list[0], torch.Tensor):
@@ -80,6 +82,7 @@ def __init__(self, dataset, batch_size, shuffle, sampler=None):
         super().__init__(sampler=sampler, batch_size=batch_size,
                          drop_last=False)
 
+
 class PinaDataModule(LightningDataModule):
     """
     This class extend LightningDataModule, allowing proper creation and
@@ -136,6 +139,7 @@ def __init__(self,
         else:
             self.predict_dataloader = super().predict_dataloader
         self.collector_splits = self._create_splits(collector, splits_dict)
+        self.transfer_batch_to_device = self._transfer_batch_to_device
 
     def setup(self, stage=None):
         """
@@ -151,7 +155,7 @@
             self.val_dataset = PinaDatasetFactory(
                 self.collector_splits['val'],
                 max_conditions_lengths=self.find_max_conditions_lengths(
-                'val'), automatic_batching=self.automatic_batching
+                    'val'), automatic_batching=self.automatic_batching
             )
         elif stage == 'test':
             self.test_dataset = PinaDatasetFactory(
@@ -215,6 +219,7 @@ def _apply_shuffle(condition_dict, len_data):
                     condition_dict[k] = v[idx]
                 else:
                     raise ValueError(f"Data type {type(v)} not supported")
+
         # ----------- End auxiliary function ------------
 
         logging.debug('Dataset creation in PinaDataModule obj')
@@ -251,14 +256,19 @@ def val_dataloader(self):
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths('val'))
                 return DataLoader(self.val_dataset, self.batch_size,
-                    collate_fn=collate)
+                                  collate_fn=collate)
             collate = Collator(None)
-            sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
+            sampler = PinaBatchSampler(self.val_dataset, self.batch_size,
+                                       shuffle=False)
             return DataLoader(self.val_dataset, sampler=sampler,
-                collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+                              collate_fn=collate)
+        dataloader = DummyDataloader(self.val_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dataloader
 
     def train_dataloader(self):
         """
@@ -273,12 +283,15 @@ def train_dataloader(self):
                                   collate_fn=collate)
             collate = Collator(None)
             sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
-                shuffle=False)
+                                       shuffle=False)
             return DataLoader(self.train_dataset, sampler=sampler,
                               collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+        dataloader = DummyDataloader(self.train_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
         return dataloader
     def test_dataloader(self):
         """
@@ -293,10 +306,10 @@ def predict_dataloader(self):
         """
         raise NotImplementedError("Predict dataloader not implemented")
 
-    def dummy_transfer_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
         return batch
 
-    def transfer_batch_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
         Transfer the batch to the device. This method is called in the
         training loop and is used to transfer the batch to the device.
@@ -307,4 +320,5 @@
                 dataloader_idx))
             for k, v in batch.items()
         ]
+        return batch
 
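The pattern the two dataloader methods implement is worth spelling out: when `batch_size` is `None`, the whole dataset is moved to the accelerator once, and the per-batch `transfer_batch_to_device` hook is then replaced with a no-op so the data is not re-transferred at every step. The sketch below illustrates that idea outside PINA and Lightning; `SimpleDataModule` and its toy dict-of-tensors dataset are invented stand-ins for this example, not the PINA API.

```python
import torch


class DummyDataloader:
    """Single-batch loader: yields the entire dataset once per epoch."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return 1

    def __iter__(self):
        yield self.dataset


class SimpleDataModule:
    """Toy stand-in showing the transfer-hook swap (not the PINA class)."""

    def __init__(self, dataset, device):
        self.dataset = dataset
        self.device = device
        # Start with the real per-batch transfer hook, as __init__ now
        # does in the patch above.
        self.transfer_batch_to_device = self._transfer_batch_to_device

    def _transfer_batch_to_device(self, batch, device, dataloader_idx):
        # Move every tensor to the device and *return* the result --
        # the missing return was one of the bugs fixed above.
        return {k: v.to(device) for k, v in batch.items()}

    def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
        # No-op hook: the data already lives on the device.
        return batch

    def train_dataloader(self):
        dataloader = DummyDataloader(self.dataset)
        # Pay the device transfer once, up front...
        dataloader.dataset = self._transfer_batch_to_device(
            dataloader.dataset, self.device, 0)
        # ...then disable the per-batch transfer for subsequent steps.
        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
        return dataloader


if __name__ == "__main__":
    dm = SimpleDataModule({'x': torch.rand(8, 2), 'y': torch.rand(8, 1)},
                          device=torch.device('cpu'))
    batch = next(iter(dm.train_dataloader()))
    # The swapped-in hook now hands the batch back untouched.
    assert dm.transfer_batch_to_device(batch, dm.device, 0) is batch
    print({k: tuple(v.shape) for k, v in batch.items()})
```

Because `self.transfer_batch_to_device = ...` assigns on the instance, it shadows the method defined on the class; that is why the patch binds the hook explicitly in `__init__` and only swaps in the dummy after the one-off transfer has run.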