Commit 260ff26: Bug fix PR #423
FilippoOlivo committed Jan 27, 2025
1 parent 6a1860b commit 260ff26
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions pina/data/data_module.py
@@ -8,6 +8,7 @@
 from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
 
+
 class DummyDataloader:
     def __init__(self, dataset, device):
         self.dataset = dataset.get_all_data()
@@ -21,6 +22,7 @@ def __len__(self):
     def __next__(self):
         return self.dataset
 
+
 class Collator:
     def __init__(self, max_conditions_lengths, ):
         self.max_conditions_lengths = max_conditions_lengths
@@ -48,7 +50,7 @@ def _collate_standard_dataloader(self, batch):
             for arg in condition_args:
                 data_list = [batch[idx][condition_name][arg] for idx in range(
                     min(len(batch),
-                    self.max_conditions_lengths[condition_name]))]
+                        self.max_conditions_lengths[condition_name]))]
                 if isinstance(data_list[0], LabelTensor):
                     single_cond_dict[arg] = LabelTensor.stack(data_list)
                 elif isinstance(data_list[0], torch.Tensor):
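Context for this hunk: the collator assembles one stacked tensor per condition field, picking the stacking routine by element type, LabelTensor.stack for LabelTensor elements and presumably torch.stack for plain tensors (that branch body lies outside the hunk). A minimal, PINA-independent sketch of the per-key stacking idea, using plain torch.Tensor only; LabelTensor and the condition bookkeeping are library-specific:

import torch

def collate_per_key(batch):
    # batch: list of dicts mapping field name -> tensor sample.
    # Stack each field across the batch, mirroring the per-condition
    # stacking done in Collator._collate_standard_dataloader.
    return {key: torch.stack([sample[key] for sample in batch])
            for key in batch[0]}

# Example: collate_per_key([{'x': torch.zeros(2)}, {'x': torch.ones(2)}])
# returns {'x': <tensor of shape (2, 2)>}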
@@ -80,6 +82,7 @@ def __init__(self, dataset, batch_size, shuffle, sampler=None):
         super().__init__(sampler=sampler, batch_size=batch_size,
                          drop_last=False)
 
+
 class PinaDataModule(LightningDataModule):
     """
     This class extend LightningDataModule, allowing proper creation and
@@ -136,6 +139,7 @@ def __init__(self,
         else:
             self.predict_dataloader = super().predict_dataloader
         self.collector_splits = self._create_splits(collector, splits_dict)
+        self.transfer_batch_to_device = self._transfer_batch_to_device
 
     def setup(self, stage=None):
         """
@@ -151,7 +155,7 @@ def setup(self, stage=None):
             self.val_dataset = PinaDatasetFactory(
                 self.collector_splits['val'],
                 max_conditions_lengths=self.find_max_conditions_lengths(
-                'val'), automatic_batching=self.automatic_batching
+                    'val'), automatic_batching=self.automatic_batching
             )
         elif stage == 'test':
             self.test_dataset = PinaDatasetFactory(
@@ -215,6 +219,7 @@ def _apply_shuffle(condition_dict, len_data):
                     condition_dict[k] = v[idx]
                 else:
                     raise ValueError(f"Data type {type(v)} not supported")
+
         # ----------- End auxiliary function ------------
 
         logging.debug('Dataset creation in PinaDataModule obj')
@@ -251,14 +256,19 @@ def val_dataloader(self):
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths('val'))
                 return DataLoader(self.val_dataset, self.batch_size,
-                              collate_fn=collate)
+                                  collate_fn=collate)
             collate = Collator(None)
-            sampler = PinaBatchSampler(self.val_dataset, self.batch_size, shuffle=False)
+            sampler = PinaBatchSampler(self.val_dataset, self.batch_size,
+                                       shuffle=False)
             return DataLoader(self.val_dataset, sampler=sampler,
-                          collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+                              collate_fn=collate)
+        dataloader = DummyDataloader(self.val_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
+        return dataloader
 
     def train_dataloader(self):
         """
@@ -273,12 +283,15 @@ def train_dataloader(self):
                                   collate_fn=collate)
             collate = Collator(None)
             sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
-                                   shuffle=False)
+                                       shuffle=False)
             return DataLoader(self.train_dataset, sampler=sampler,
                               collate_fn=collate)
-        dataloader = DummyDataloader(self.train_dataset, self.trainer.strategy.root_device)
-        dataloader.dataset = self.transfer_batch_to_device(dataloader.dataset, self.trainer.strategy.root_device, 0)
-        self.transfer_batch_to_device = self.dummy_transfer_to_device
+        dataloader = DummyDataloader(self.train_dataset,
+                                     self.trainer.strategy.root_device)
+        dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
+                                                            self.trainer.strategy.root_device,
+                                                            0)
+        self.transfer_batch_to_device = self._transfer_batch_to_device_dummy
         return dataloader
 
     def test_dataloader(self):
@@ -293,10 +306,10 @@ def predict_dataloader(self):
         """
         raise NotImplementedError("Predict dataloader not implemented")
 
-    def dummy_transfer_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device_dummy(self, batch, device, dataloader_idx):
         return batch
 
-    def transfer_batch_to_device(self, batch, device, dataloader_idx):
+    def _transfer_batch_to_device(self, batch, device, dataloader_idx):
         """
         Transfer the batch to the device. This method is called in the
         training loop and is used to transfer the batch to the device.
@@ -307,4 +320,5 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
                                                  dataloader_idx))
             for k, v in batch.items()
         ]
+
         return batch
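For completeness: the renamed _transfer_batch_to_device walks the condition dictionary and hands each value to the parent Lightning hook (visible in the dataloader_idx)) continuation above, which builds a list of pairs). A generic sketch of moving a dict-of-tensors batch to a device, not PINA's exact implementation:

import torch

def transfer_batch(batch, device):
    # Move every tensor value of a {name: tensor} batch to `device`;
    # non-tensor values are passed through unchanged.
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()}

# Example: transfer_batch({'input': torch.zeros(4)}, torch.device('cpu'))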
