Rebase
FilippoOlivo committed Jan 25, 2025
1 parent 7dac9c7 commit 290641a
Showing 4 changed files with 66 additions and 53 deletions.

pina/data/data_module.py (33 additions, 36 deletions)

@@ -7,6 +7,7 @@
     RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from .dataset import PinaDatasetFactory
+from torch_geometric.data import Data, Batch


 class DummyDataloader:
@@ -24,15 +25,15 @@ def __next__(self):


 class Collator:
-    def __init__(self, max_conditions_lengths, ):
+    def __init__(self, max_conditions_lengths, dataset=None):
         self.max_conditions_lengths = max_conditions_lengths
         self.callable_function = self._collate_custom_dataloader if \
             max_conditions_lengths is None else (
             self._collate_standard_dataloader)
+        self.dataset = dataset

-    @staticmethod
-    def _collate_custom_dataloader(batch):
-        return batch[0]
+    def _collate_custom_dataloader(self, batch):
+        return self.dataset.__getitem_list__(batch)

     def _collate_standard_dataloader(self, batch):
         """
@@ -50,37 +51,36 @@ def _collate_standard_dataloader(self, batch):
             for arg in condition_args:
                 data_list = [batch[idx][condition_name][arg] for idx in range(
                     min(len(batch),
-                       self.max_conditions_lengths[condition_name]))]
+                        self.max_conditions_lengths[condition_name]))]
                 if isinstance(data_list[0], LabelTensor):
                     single_cond_dict[arg] = LabelTensor.stack(data_list)
                 elif isinstance(data_list[0], torch.Tensor):
                     single_cond_dict[arg] = torch.stack(data_list)
+                elif isinstance(data_list[0], Data):
+                    single_cond_dict[arg] = Batch.from_data_list(data_list)
                 else:
                     raise NotImplementedError(
                         f"Data type {type(data_list[0])} not supported")
             batch_dict[condition_name] = single_cond_dict
         return batch_dict

     def __call__(self, batch):
         return self.callable_function(batch)


-class PinaBatchSampler(BatchSampler):
-    def __init__(self, dataset, batch_size, shuffle, sampler=None):
-        if sampler is None:
-            if (torch.distributed.is_available() and
-                    torch.distributed.is_initialized()):
-                rank = torch.distributed.get_rank()
-                world_size = torch.distributed.get_world_size()
-                sampler = DistributedSampler(dataset, shuffle=shuffle,
-                                             rank=rank, num_replicas=world_size)
-            else:
-                if shuffle:
-                    sampler = RandomSampler(dataset)
-                else:
-                    sampler = SequentialSampler(dataset)
-        super().__init__(sampler=sampler, batch_size=batch_size,
-                         drop_last=False)
+class PinaSampler:
+    def __new__(self, dataset, batch_size, shuffle, automatic_batching):
+
+        if (torch.distributed.is_available() and
+                torch.distributed.is_initialized()):
+            sampler = DistributedSampler(dataset, shuffle=shuffle)
+        else:
+            if shuffle:
+                sampler = RandomSampler(dataset)
+            else:
+                sampler = SequentialSampler(dataset)
+        return sampler


 class PinaDataModule(LightningDataModule):
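
Two things change in this hunk: the standard collate path learns to merge torch_geometric `Data` objects via `Batch.from_data_list`, and the `BatchSampler` subclass is replaced by a plain factory that returns a ready-made sampler. A hedged illustration of what `Batch.from_data_list` produces (requires torch_geometric; the printed shapes hold for this toy input only):

```python
import torch
from torch_geometric.data import Batch, Data

graphs = [
    Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]])),
    Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0], [1]])),
]
batch = Batch.from_data_list(graphs)  # one big disconnected graph
print(batch.num_graphs)  # 2
print(batch.x.shape)     # torch.Size([5, 4]): node features concatenated
print(batch.batch)       # tensor([0, 0, 0, 1, 1]): owning graph per node
```
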
@@ -252,16 +252,14 @@ def val_dataloader(self):
         """
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            # Use default batching in torch DataLoader (good is batch size is small)
+            sampler = PinaSampler(self.val_dataset, self.batch_size,
+                                  self.shuffle, self.automatic_batching)
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths('val'))
-                return DataLoader(self.val_dataset, self.batch_size,
-                                  collate_fn=collate)
-            collate = Collator(None)
-            sampler = PinaBatchSampler(self.val_dataset, self.batch_size,
-                                       shuffle=False)
-            return DataLoader(self.val_dataset, sampler=sampler,
-                              collate_fn=collate)
+            else:
+                collate = Collator(None, self.val_dataset)
+            return DataLoader(self.val_dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
         dataloader = DummyDataloader(self.val_dataset,
                                      self.trainer.strategy.root_device)
         dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
@@ -276,16 +274,15 @@ def train_dataloader(self):
         """
         # Use custom batching (good if batch size is large)
         if self.batch_size is not None:
-            # Use default batching in torch DataLoader (good is batch size is small)
+            sampler = PinaSampler(self.train_dataset, self.batch_size,
+                                  self.shuffle, self.automatic_batching)
             if self.automatic_batching:
                 collate = Collator(self.find_max_conditions_lengths('train'))
-                return DataLoader(self.train_dataset, self.batch_size,
-                                  collate_fn=collate)
-            collate = Collator(None)
-            sampler = PinaBatchSampler(self.train_dataset, self.batch_size,
-                                       shuffle=False)
-            return DataLoader(self.train_dataset, sampler=sampler,
-                              collate_fn=collate)
+            else:
+                collate = Collator(None, self.train_dataset)
+            return DataLoader(self.train_dataset, self.batch_size,
+                              collate_fn=collate, sampler=sampler)
+
         dataloader = DummyDataloader(self.train_dataset,
                                      self.trainer.strategy.root_device)
         dataloader.dataset = self._transfer_batch_to_device(dataloader.dataset,
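
For reference, the pieces above now compose as follows: `PinaSampler` picks a distributed, random, or sequential sampler, and the result is passed straight to the `DataLoader` together with `batch_size` and the `Collator`. A runnable sketch of the selection logic, using a stand-in function name (`make_sampler` is illustrative, not PINA's API):

```python
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler


def make_sampler(dataset, shuffle):
    # Mirrors PinaSampler.__new__: distributed when a process group is
    # initialized, otherwise random or sequential depending on `shuffle`.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return DistributedSampler(dataset, shuffle=shuffle)
    return RandomSampler(dataset) if shuffle else SequentialSampler(dataset)


dataset = TensorDataset(torch.arange(10.0))
loader = DataLoader(dataset, batch_size=4,
                    sampler=make_sampler(dataset, shuffle=False))
for (batch,) in loader:
    print(batch)  # tensor([0., 1., 2., 3.]), tensor([4., 5., 6., 7.]), tensor([8., 9.])
```

Returning a plain sampler and letting `DataLoader` do the batching is simpler than the previous `BatchSampler` subclass, since `DataLoader` accepts `sampler` and `batch_size` together.
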

pina/data/dataset.py (30 additions, 15 deletions)

@@ -64,7 +64,7 @@ def _getitem_int(self, idx):
                 in v.keys()} for k, v in self.conditions_dict.items()
         }

-    def _getitem_list(self, idx):
+    def __getitem_list__(self, idx):
         to_return_dict = {}
         for condition, data in self.conditions_dict.items():
             cond_idx = idx[:self.max_conditions_lengths[condition]]
@@ -75,34 +75,49 @@
                                          for k, v in data.items()}
         return to_return_dict

+    @staticmethod
+    def _getitem_list(idx):
+        return idx
+
     def get_all_data(self):
         index = [i for i in range(len(self))]
-        return self._getitem_list(index)
+        return self.__getitem_list__(index)

     def __getitem__(self, idx):
         return self._getitem_func(idx)


 class PinaGraphDataset(PinaDataset):
-    pass
-    """
-    def __init__(self, conditions_dict, max_conditions_lengths):
+    def __init__(self, conditions_dict, max_conditions_lengths,
+                 automatic_batching):
         super().__init__(conditions_dict, max_conditions_lengths)
+        if automatic_batching:
+            self._getitem_func = self._getitem_int
+        else:
+            self._getitem_func = self._getitem_list

-    def __getitem__(self, idx):
-        Getitem method for large batch size
+    def __getitem_list__(self, idx):
         to_return_dict = {}
         for condition, data in self.conditions_dict.items():
             cond_idx = idx[:self.max_conditions_lengths[condition]]
             condition_len = self.conditions_length[condition]
             if self.length > condition_len:
                 cond_idx = [idx%condition_len for idx in cond_idx]
             to_return_dict[condition] = {k: Batch.from_data_list([v[i]
-                                             for i in cond_idx])
-                                         if isinstance(v, list)
-                                         else v[cond_idx].tensor.reshape(-1, v.size(-1))
-                                         for k, v in data.items()
-                                         }
+                                             for i in cond_idx])
+                                         if isinstance(v, list)
+                                         else v[cond_idx]
+                                         for k, v in data.items()
+                                         }
         return to_return_dict
-    """
+
+    def _getitem_list(self, idx):
+        return idx
+
+    def _getitem_int(self, idx):
+        return {
+            k: {k_data: v[k_data][idx % len(v['input_points'])] for k_data
+                in v.keys()} for k, v in self.conditions_dict.items()
+        }
+
+    def __getitem__(self, idx):
+        return self._getitem_func(idx)
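
The gist of `__getitem_list__` above: a dataset can hold several conditions of different lengths, each batch of global indices is truncated to a per-condition cap, and indices that overrun a short condition wrap around with a modulo so that every condition contributes to every batch. A toy sketch with hypothetical names (`conditions`, `getitem_list`), simplifying the length check used in the real code:

```python
conditions = {
    "physics": list(range(100)),  # 100 samples
    "boundary": list(range(10)),  # only 10 samples
}
max_conditions_lengths = {"physics": 8, "boundary": 4}


def getitem_list(idx):
    out = {}
    for name, data in conditions.items():
        cond_idx = idx[:max_conditions_lengths[name]]  # per-condition cap
        cond_idx = [i % len(data) for i in cond_idx]   # wrap short conditions
        out[name] = [data[i] for i in cond_idx]
    return out


print(getitem_list([90, 91, 92, 93, 94, 95, 96, 97]))
# {'physics': [90, 91, 92, 93, 94, 95, 96, 97], 'boundary': [0, 1, 2, 3]}
```

In `PinaGraphDataset` the same gather additionally merges list-valued entries (graphs) with `Batch.from_data_list` instead of plain indexing.
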

pina/solvers/supervised.py (0 additions, 1 deletion)

@@ -73,7 +73,6 @@ def __init__(self,
             problem=problem,
             optimizers=optimizer,
             schedulers=scheduler,
-            extra_features=extra_features,
             use_lt=use_lt)

         # check consistency

pina/trainer.py (3 additions, 1 deletion)

@@ -72,12 +72,14 @@ def _create_loader(self):
             raise RuntimeError('Cannot create Trainer if not all conditions '
                                'are sampled. The Trainer got the following:\n'
                                f'{error_message}')
+        automatic_batching = False
         self.data_module = PinaDataModule(collector=self.solver.problem.collector,
                                           train_size=self.train_size,
                                           test_size=self.test_size,
                                           val_size=self.val_size,
                                           predict_size=self.predict_size,
-                                          batch_size=self.batch_size,)
+                                          batch_size=self.batch_size,
+                                          automatic_batching=automatic_batching)

     def train(self, **kwargs):
         """
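
The new `automatic_batching` flag (hard-coded to `False` here for now) selects between the two dataset access patterns wired up in dataset.py: per-sample `__getitem__` calls collated afterwards, versus index lists resolved in one shot. For background, a small sketch of what automatic batching means for a plain torch `DataLoader` (plain torch semantics, not PINA code):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class Squares(Dataset):
    def __len__(self):
        return 6

    def __getitem__(self, i):
        return torch.tensor(i * i)


auto = DataLoader(Squares(), batch_size=3)       # automatic batching: collates samples
print(next(iter(auto)))                          # tensor([0, 1, 4])

manual = DataLoader(Squares(), batch_size=None)  # batching disabled entirely
print(next(iter(manual)))                        # tensor(0)
```
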
