From f660a215a78a69bdad47fd20bca44a08decd6936 Mon Sep 17 00:00:00 2001 From: Chandramouli Shama Sastry Date: Wed, 14 Feb 2024 00:48:15 +0000 Subject: [PATCH] deepspeech modeldiffs --- .../librispeech_pytorch/models.py | 16 +++--- .../__init__.py | 0 .../compare.py | 53 +++++++++++++++++++ .../__init__.py | 0 .../librispeech_deepspeech_normaug/compare.py | 53 +++++++++++++++++++ .../librispeech_deepspeech_tanh/__init__.py | 0 .../librispeech_deepspeech_tanh/compare.py | 53 +++++++++++++++++++ 7 files changed, 167 insertions(+), 8 deletions(-) create mode 100644 tests/modeldiffs/librispeech_deepspeech_noresnet/__init__.py create mode 100644 tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py create mode 100644 tests/modeldiffs/librispeech_deepspeech_normaug/__init__.py create mode 100644 tests/modeldiffs/librispeech_deepspeech_normaug/compare.py create mode 100644 tests/modeldiffs/librispeech_deepspeech_tanh/__init__.py create mode 100644 tests/modeldiffs/librispeech_deepspeech_tanh/compare.py diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index d270df236..a5ee3fa0a 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -200,7 +200,7 @@ def __init__(self, config: DeepspeechConfig): if config.layernorm_everywhere: self.normalization_layer = LayerNorm(config.encoder_dim) else: - self.normalization_layer = BatchNorm( + self.bn_normalization_layer = BatchNorm( dim=config.encoder_dim, batch_norm_momentum=config.batch_norm_momentum, batch_norm_epsilon=config.batch_norm_epsilon) @@ -216,7 +216,7 @@ def forward(self, inputs, input_paddings): if self.config.layernorm_everywhere: inputs = self.normalization_layer(inputs) else: # batchnorm - inputs = self.normalization_layer(inputs, 
input_paddings) + inputs = self.bn_normalization_layer(inputs, input_paddings) inputs = self.lin(inputs) @@ -288,11 +288,11 @@ def __init__(self, config: DeepspeechConfig): self.bidirectional = bidirectional if config.layernorm_everywhere: - self.normalization_layer = nn.LayerNorm(config.encoder_dim) + self.normalization_layer = LayerNorm(config.encoder_dim) else: - self.normalization_layer = BatchNorm(config.encoder_dim, - config.batch_norm_momentum, - config.batch_norm_epsilon) + self.bn_normalization_layer = BatchNorm(config.encoder_dim, + config.batch_norm_momentum, + config.batch_norm_epsilon) if bidirectional: self.lstm = nn.LSTM( @@ -308,7 +308,7 @@ def forward(self, inputs, input_paddings): if self.config.layernorm_everywhere: inputs = self.normalization_layer(inputs) else: - inputs = self.normalization_layer(inputs, input_paddings) + inputs = self.bn_normalization_layer(inputs, input_paddings) lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy() packed_inputs = torch.nn.utils.rnn.pack_padded_sequence( inputs, lengths, batch_first=True, enforce_sorted=False) @@ -357,7 +357,7 @@ def __init__(self, config: DeepspeechConfig): [FeedForwardModule(config) for _ in range(config.num_ffn_layers)]) if config.enable_decoder_layer_norm: - self.ln = nn.LayerNorm(config.encoder_dim) + self.ln = LayerNorm(config.encoder_dim) else: self.ln = nn.Identity() diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/__init__.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py new file mode 100644 index 000000000..6c00bdf69 --- /dev/null +++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py @@ -0,0 +1,53 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ + LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ + LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.librispeech_deepspeech.compare import key_transform +from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PyTorchWorkload() + + # Test outputs for identical weights and inputs. + wave = torch.randn(2, 320000) + pad = torch.zeros_like(wave) + pad[0, 200000:] = 1 + + jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} + pyt_batch = {'inputs': (wave, pad)} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] * + (1 - out_outpad[1][:, :, None])) diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/__init__.py b/tests/modeldiffs/librispeech_deepspeech_normaug/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py new file mode
100644 index 000000000..c68d6adf9 --- /dev/null +++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py @@ -0,0 +1,53 @@ +import os + +# Disable GPU access for both jax and pytorch. +os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ + LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ + LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.librispeech_deepspeech.compare import key_transform +from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PyTorchWorkload() + + # Test outputs for identical weights and inputs. 
+ wave = torch.randn(2, 320000) + pad = torch.zeros_like(wave) + pad[0, 200000:] = 1 + + jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} + pyt_batch = {'inputs': (wave, pad)} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] * + (1 - out_outpad[1][:, :, None])) diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/__init__.py b/tests/modeldiffs/librispeech_deepspeech_tanh/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py new file mode 100644 index 000000000..4cfdf4f21 --- /dev/null +++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py @@ -0,0 +1,53 @@ +import os + +# Disable GPU access for both jax and pytorch. 
+os.environ['CUDA_VISIBLE_DEVICES'] = '' + +import jax +import torch + +from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ + LibriSpeechDeepSpeechTanhWorkload as JaxWorkload +from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ + LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload +from tests.modeldiffs.diff import out_diff +from tests.modeldiffs.librispeech_deepspeech.compare import key_transform +from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform + +if __name__ == '__main__': + # pylint: disable=locally-disabled, not-callable + + jax_workload = JaxWorkload() + pytorch_workload = PyTorchWorkload() + + # Test outputs for identical weights and inputs. + wave = torch.randn(2, 320000) + pad = torch.zeros_like(wave) + pad[0, 200000:] = 1 + + jax_batch = {'inputs': (wave.detach().numpy(), pad.detach().numpy())} + pyt_batch = {'inputs': (wave, pad)} + + pytorch_model_kwargs = dict( + augmented_and_preprocessed_input_batch=pyt_batch, + model_state=None, + mode=spec.ForwardPassMode.EVAL, + rng=None, + update_batch_norm=False) + + jax_model_kwargs = dict( + augmented_and_preprocessed_input_batch=jax_batch, + mode=spec.ForwardPassMode.EVAL, + rng=jax.random.PRNGKey(0), + update_batch_norm=False) + + out_diff( + jax_workload=jax_workload, + pytorch_workload=pytorch_workload, + jax_model_kwargs=jax_model_kwargs, + pytorch_model_kwargs=pytorch_model_kwargs, + key_transform=key_transform, + sd_transform=sd_transform, + out_transform=lambda out_outpad: out_outpad[0] * + (1 - out_outpad[1][:, :, None]))