Allow use of dataloader when calling TrainingPipeline (#100)
* Allow use of dataloader when calling TrainingPipeline

* Small changes to #100

---------

Co-authored-by: clementchadebec <[email protected]>
paul-english and clementchadebec authored Sep 14, 2023
1 parent ce4e562 commit 0ab36ba
Showing 3 changed files with 100 additions and 45 deletions.
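For reference, a minimal sketch of the usage this commit enables: passing a ready-made `DataLoader` straight to `TrainingPipeline`. The toy dataset, tensor shapes, `output_dir`, and `num_epochs` below are illustrative placeholders, not taken from the repository; the `DatasetOutput`-based dataset mirrors the one used in the new test.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from pythae.data.datasets import DatasetOutput
from pythae.models import VAE, VAEConfig
from pythae.pipelines import TrainingPipeline
from pythae.trainers import BaseTrainerConfig


class ToyDataset(Dataset):
    """Illustrative dataset returning the DatasetOutput mapping pythae expects."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return DatasetOutput(data=self.data[idx])


data = torch.rand(100, 1, 28, 28)  # placeholder MNIST-shaped tensors
train_loader = DataLoader(ToyDataset(data), batch_size=32, shuffle=True)

model = VAE(VAEConfig(input_dim=(1, 28, 28), latent_dim=2))
config = BaseTrainerConfig(output_dir="my_model", num_epochs=1)

pipeline = TrainingPipeline(model=model, training_config=config)
# With this commit, a DataLoader can be passed directly, in addition to
# numpy arrays, torch tensors, and torch Datasets.
pipeline(train_data=train_loader, eval_data=train_loader)
```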
73 changes: 41 additions & 32 deletions src/pythae/pipelines/training.py
@@ -51,8 +51,8 @@ def __init__(
if training_config is None:
if model.model_name == "RAE_L2":
training_config = CoupledOptimizerTrainerConfig(
encoder_optim_decay=0,
decoder_optim_decay=model.model_config.reg_weight,
encoder_optimizer_params={"weight_decay": 0},
decoder_optimizer_params={"weight_decay": model.model_config.reg_weight},
)

elif (
@@ -153,59 +153,69 @@ def _check_dataset(self, dataset: BaseDataset):

def __call__(
self,
train_data: Union[np.ndarray, torch.Tensor, torch.utils.data.Dataset],
eval_data: Union[np.ndarray, torch.Tensor, torch.utils.data.Dataset] = None,
train_data: Union[
np.ndarray,
torch.Tensor,
torch.utils.data.Dataset,
torch.utils.data.DataLoader,
] = None,
eval_data: Union[
np.ndarray,
torch.Tensor,
torch.utils.data.Dataset,
torch.utils.data.DataLoader,
] = None,
callbacks: List[TrainingCallback] = None,
):
"""
Launch the model training on the provided data.
Args:
training_data (Union[~numpy.ndarray, ~torch.Tensor]): The training data as a
:class:`numpy.ndarray` or :class:`torch.Tensor` of shape (mini_batch x
n_channels x ...)
train_data: The training data or DataLoader.
eval_data (Optional[Union[~numpy.ndarray, ~torch.Tensor]]): The evaluation data as a
:class:`numpy.ndarray` or :class:`torch.Tensor` of shape (mini_batch x
n_channels x ...). If None, only uses train_fata for training. Default: None.
eval_data: The evaluation data or DataLoader. If None, only uses train_data for training. Default: None
callbacks (List[~pythae.trainers.training_callbacks.TrainingCallbacks]):
A list of callbacks to use during training.
"""

if isinstance(train_data, np.ndarray) or isinstance(train_data, torch.Tensor):
# Initialize variables for datasets and dataloaders
train_dataset, eval_dataset = None, None
train_dataloader, eval_dataloader = None, None

if isinstance(train_data, torch.utils.data.DataLoader):
train_dataloader = train_data
elif isinstance(train_data, (np.ndarray, torch.Tensor)):
logger.info("Preprocessing train data...")
train_data = self.data_processor.process_data(train_data)
train_dataset = self.data_processor.to_dataset(train_data)

logger.info("Checking train dataset...")
self._check_dataset(train_dataset)
else:
train_dataset = train_data

logger.info("Checking train dataset...")
self._check_dataset(train_dataset)
logger.info("Checking train dataset...")
self._check_dataset(train_dataset)

if eval_data is not None:
if isinstance(eval_data, np.ndarray) or isinstance(eval_data, torch.Tensor):
if isinstance(eval_data, torch.utils.data.DataLoader):
eval_dataloader = eval_data
elif isinstance(eval_data, (np.ndarray, torch.Tensor)):
logger.info("Preprocessing eval data...\n")
eval_data = self.data_processor.process_data(eval_data)
eval_dataset = self.data_processor.to_dataset(eval_data)

logger.info("Checking eval dataset...")
self._check_dataset(eval_dataset)
else:
eval_dataset = eval_data

logger.info("Checking eval dataset...")
self._check_dataset(eval_dataset)

else:
eval_dataset = None
logger.info("Checking eval dataset...")
self._check_dataset(eval_dataset)

if isinstance(self.training_config, CoupledOptimizerTrainerConfig):
logger.info("Using Coupled Optimizer Trainer\n")
trainer = CoupledOptimizerTrainer(
model=self.model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset=train_dataloader or train_dataset,
eval_dataset=eval_dataloader or eval_dataset,
training_config=self.training_config,
callbacks=callbacks,
)
@@ -214,8 +224,8 @@ def __call__(
logger.info("Using Adversarial Trainer\n")
trainer = AdversarialTrainer(
model=self.model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset=train_dataloader or train_dataset,
eval_dataset=eval_dataloader or eval_dataset,
training_config=self.training_config,
callbacks=callbacks,
)
@@ -224,8 +234,8 @@ def __call__(
logger.info("Using Coupled Optimizer Adversarial Trainer\n")
trainer = CoupledOptimizerAdversarialTrainer(
model=self.model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset=train_dataloader or train_dataset,
eval_dataset=eval_dataloader or eval_dataset,
training_config=self.training_config,
callbacks=callbacks,
)
@@ -234,14 +244,13 @@ def __call__(
logger.info("Using Base Trainer\n")
trainer = BaseTrainer(
model=self.model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset=train_dataloader or train_dataset,
eval_dataset=eval_dataloader or eval_dataset,
training_config=self.training_config,
callbacks=callbacks,
)
else:
raise ValueError("The provided training config is not supported.")

self.trainer = trainer

trainer.train()
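Besides the dataloader dispatch, the first hunk also switches the RAE_L2 default trainer config from the old `encoder_optim_decay` / `decoder_optim_decay` keywords to the optimizer-params dictionaries. A sketch of the equivalent explicit config, with an illustrative weight-decay value standing in for `model.model_config.reg_weight`:

```python
# Sketch only: the explicit form of the RAE_L2 default built in the pipeline.
from pythae.trainers import CoupledOptimizerTrainerConfig

training_config = CoupledOptimizerTrainerConfig(
    output_dir="my_model",                            # illustrative
    encoder_optimizer_params={"weight_decay": 0},     # no L2 penalty on the encoder
    decoder_optimizer_params={"weight_decay": 1e-4},  # stands in for reg_weight
)
```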
27 changes: 20 additions & 7 deletions src/pythae/trainers/base_trainer/base_trainer.py
@@ -3,14 +3,14 @@
import logging
import os
from copy import deepcopy
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from ...customexception import ModelError
@@ -56,8 +56,8 @@ class BaseTrainer:
def __init__(
self,
model: BaseAE,
train_dataset: BaseDataset,
eval_dataset: Optional[BaseDataset] = None,
train_dataset: Union[BaseDataset, DataLoader],
eval_dataset: Optional[Union[BaseDataset, DataLoader]] = None,
training_config: Optional[BaseTrainerConfig] = None,
callbacks: List[TrainingCallback] = None,
):
@@ -119,11 +119,24 @@ def __init__(
self.eval_dataset = eval_dataset

# Define the loaders
train_loader = self.get_train_dataloader(train_dataset)
if isinstance(train_dataset, DataLoader):
train_loader = train_dataset
logger.warn(
"Using the provided train dataloader! Carefull this may overwrite some "
"parameters provided in your training config."
)
else:
train_loader = self.get_train_dataloader(train_dataset)

if eval_dataset is not None:
eval_loader = self.get_eval_dataloader(eval_dataset)

if isinstance(eval_dataset, DataLoader):
eval_loader = eval_dataset
logger.warn(
"Using the provided eval dataloader! Carefull this may overwrite some "
"parameters provided in your training config."
)
else:
eval_loader = self.get_eval_dataloader(eval_dataset)
else:
logger.info(
"! No eval dataset provided ! -> keeping best model on train.\n"
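For completeness, a sketch of driving `BaseTrainer` directly with a prebuilt `DataLoader`, which the widened `train_dataset` annotation above now allows. It reuses the toy `model` and `train_loader` from the earlier sketch; as the added warning notes, the loader's own batch size and sampler are used as-is, so batch-size settings in the training config may effectively be overridden.

```python
# Sketch only: `model` and `train_loader` come from the pipeline sketch above.
from pythae.trainers import BaseTrainer, BaseTrainerConfig

trainer = BaseTrainer(
    model=model,
    train_dataset=train_loader,  # a DataLoader is now accepted directly
    eval_dataset=train_loader,
    training_config=BaseTrainerConfig(output_dir="my_model", num_epochs=1),
)
trainer.train()
```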
45 changes: 39 additions & 6 deletions tests/test_pipeline_standalone.py
@@ -2,11 +2,11 @@

import pytest
import torch
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader

from pythae.customexception import DatasetError
from pythae.data.datasets import DatasetOutput
from pythae.models import VAE, FactorVAE, FactorVAEConfig, VAEConfig
from pythae.models import VAE, VAEConfig, Adversarial_AE, Adversarial_AE_Config, RAE_L2, RAE_L2_Config, VAEGAN, VAEGANConfig
from pythae.pipelines import *
from pythae.samplers import NormalSampler, NormalSamplerConfig
from pythae.trainers import BaseTrainerConfig
@@ -69,22 +69,39 @@ def custom_no_len_train_dataset(self):
return CustomWrongOutputDataset(
os.path.join(PATH, "data/mnist_clean_train_dataset_sample")
)

@pytest.fixture(
params=[
(VAE, VAEConfig),
(Adversarial_AE, Adversarial_AE_Config),
(RAE_L2, RAE_L2_Config),
(VAEGAN, VAEGANConfig)
]
)
def model(self, request, train_dataset):
model = request.param[0](request.param[1](input_dim=tuple(train_dataset.data[0].shape), latent_dim=2))

return model

@pytest.fixture
def training_pipeline(self, train_dataset):
def train_dataloader(self, custom_train_dataset):
return DataLoader(dataset=custom_train_dataset, batch_size=32)

@pytest.fixture
def training_pipeline(self, model, train_dataset):
vae_config = VAEConfig(
input_dim=tuple(train_dataset.data[0].shape), latent_dim=2
)
vae = VAE(vae_config)
pipe = TrainingPipeline(model=vae)
pipe = TrainingPipeline(model=model)
return pipe

def test_base_pipeline(self):
with pytest.raises(NotImplementedError):
pipe = Pipeline()
pipe()

def test_training_pipeline(self, tmpdir, training_pipeline, train_dataset):
def test_training_pipeline(self, tmpdir, training_pipeline, train_dataset, model):

with pytest.raises(AssertionError):
pipeline = TrainingPipeline(
@@ -96,7 +113,10 @@ def test_training_pipeline(self, tmpdir, training_pipeline, train_dataset):
training_pipeline.training_config.output_dir = dir_path
training_pipeline.training_config.num_epochs = 1
training_pipeline(train_dataset.data)
assert isinstance(training_pipeline.model, VAE)
assert isinstance(training_pipeline.model, model.__class__)

if model.__class__ == RAE_L2:
assert training_pipeline.trainer.decoder_optimizer.state_dict()['param_groups'][0]['weight_decay'] == model.model_config.reg_weight

def test_training_pipeline_wrong_output_dataset(
self,
@@ -179,3 +199,16 @@ def test_generation_pipeline(self, tmpdir, train_dataset):
assert tuple(gen_data.shape) == (1,) + (1, 2, 3)
assert len(os.listdir(dir_path)) == 1 + 1
assert "sampler_config.json" in os.listdir(dir_path)

def test_training_pipeline_with_dataloader(
self, tmpdir, training_pipeline, train_dataloader
):
# Simulate a training run with a DataLoader
tmpdir.mkdir("dataloader_test")
dir_path = os.path.join(tmpdir, "dataloader_test")
training_pipeline.training_config.output_dir = dir_path
training_pipeline.training_config.num_epochs = 1
training_pipeline(train_data=train_dataloader, eval_data=train_dataloader)

assert isinstance(training_pipeline.trainer.train_loader, DataLoader)
assert isinstance(training_pipeline.trainer.eval_loader, DataLoader)
