Release 0.0.8
New Features:
- Added `MLFlowCallback` to the training callbacks, further to #44
- Allow custom `Dataset`s inheriting from `torch.utils.data.Dataset` to be passed as inputs to the `training_pipeline`, further to #35
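The pipeline's `__call__` signature now reads:

```python
from typing import List, Union
import numpy as np
import torch
from pythae.trainers.training_callbacks import TrainingCallback

def __call__(
    self,
    train_data: Union[np.ndarray, torch.Tensor, torch.utils.data.Dataset],
    eval_data: Union[np.ndarray, torch.Tensor, torch.utils.data.Dataset] = None,
    callbacks: List[TrainingCallback] = None,
):
    ...  # body omitted
```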
- Added implementations of the Multiply, Partially and Combination Importance Weighted Autoencoders: `MIWAE`, `PIWAE` and `CIWAE` (https://arxiv.org/abs/1802.04537)
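The custom `Dataset` support can be illustrated with a minimal PyTorch map-style dataset. This is a sketch under an assumption: that pythae's models read each batch through a `"data"` key, so `__getitem__` here returns a dict with that key. The class name `MyTensorDataset` is hypothetical; check pythae's documentation for the exact item format the `training_pipeline` expects.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class MyTensorDataset(Dataset):
    """Hypothetical map-style dataset wrapping a tensor of samples."""

    def __init__(self, data: torch.Tensor):
        self.data = data.float()

    def __len__(self) -> int:
        # Number of samples along the first dimension
        return self.data.shape[0]

    def __getitem__(self, index: int) -> dict:
        # Assumption: samples are exposed under a "data" key
        return {"data": self.data[index]}


train_dataset = MyTensorDataset(torch.rand(100, 1, 28, 28))
# The default collate function stacks the "data" entries into one batch tensor.
batch = next(iter(DataLoader(train_dataset, batch_size=8)))
print(batch["data"].shape)  # torch.Size([8, 1, 28, 28])
```

An instance such as `train_dataset` would then be passed as `train_data` (and, optionally, another one as `eval_data`) when calling the pipeline.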
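For readers new to the IWAE variants, here is a minimal pure-Python sketch of the bounds from Rainforth et al. (https://arxiv.org/abs/1802.04537), computed from log importance weights log w_k = log p(x, z_k) - log q(z_k | x) for a single data point. It is not pythae's implementation, and the function names are hypothetical. MIWAE averages M independent K-sample IWAE bounds; CIWAE is the convex combination beta * ELBO + (1 - beta) * IWAE. PIWAE, which uses different targets for the encoder and the decoder, is not shown.

```python
import math
from typing import List, Sequence


def log_mean_exp(log_w: Sequence[float]) -> float:
    """Numerically stable log((1/K) * sum_k exp(log_w[k]))."""
    m = max(log_w)
    return m + math.log(sum(math.exp(x - m) for x in log_w) / len(log_w))


def elbo(log_w: Sequence[float]) -> float:
    """K-sample ELBO estimate: mean of the log weights."""
    return sum(log_w) / len(log_w)


def iwae_bound(log_w: Sequence[float]) -> float:
    """IWAE bound: log of the mean importance weight."""
    return log_mean_exp(log_w)


def miwae_bound(log_w_groups: List[Sequence[float]]) -> float:
    """MIWAE(M, K): average of M independent IWAE bounds with K samples each."""
    return sum(iwae_bound(g) for g in log_w_groups) / len(log_w_groups)


def ciwae_bound(log_w: Sequence[float], beta: float) -> float:
    """CIWAE(beta): convex combination of the ELBO and IWAE bounds."""
    return beta * elbo(log_w) + (1.0 - beta) * iwae_bound(log_w)


# Toy log weights for a single data point (made-up numbers, K = 2).
log_w = [0.0, math.log(3.0)]
print(iwae_bound(log_w))  # log(2), about 0.693
```

Because log is concave, Jensen's inequality gives `elbo(log_w) <= iwae_bound(log_w)`, so CIWAE moves from the ELBO to the IWAE bound as beta goes from 1 to 0.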
Minor changes:
- Unify data handling in `FactorVAE` with the other models (half of the batch is used for reconstruction and the other half for the factorial representation)
- Change the model sanity check method in `trainers` (use loaders in the check instead of datasets)
- Add the encoder/decoder losses needed in `CoupledOptimizerTrainer` and update tests