diff --git a/deel/torchlip/modules/module.py b/deel/torchlip/modules/module.py index a951a61..43faebb 100644 --- a/deel/torchlip/modules/module.py +++ b/deel/torchlip/modules/module.py @@ -72,13 +72,13 @@ def vanilla_model(model: nn.Module): model (nn.Module): Lipschitz neural network """ for n, module in model.named_children(): - if len(list(module.children())) > 0: - # compound module, go inside it - vanilla_model(module) - if isinstance(module, LipschitzModule): # simple module setattr(model, n, module.vanilla_export()) + elif len(list(module.children())) > 0: + # compound module, go inside it + vanilla_model(module) + class _LipschitzCoefMultiplication(nn.Module):