Skip to content

Commit

Permalink
Cast spectral norm output to fprop_dtype.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 645449891
  • Loading branch information
The praxis Authors committed Jun 21, 2024
1 parent fee8035 commit d8346c4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
3 changes: 2 additions & 1 deletion praxis/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,4 +669,5 @@ def __call__(

norm = v @ w @ u
wn = w / norm
return jnp.reshape(wn, inputs.shape)
wn = jnp.reshape(wn, inputs.shape)
return self._cast_to_fprop_dtype(wn)
23 changes: 23 additions & 0 deletions praxis/layers/normalizations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,29 @@ def test_spectral_norm(self, kernel_size: int, do_eval: bool):
updated[NON_TRAINABLE]['u'], init[NON_TRAINABLE]['u']
)

@parameterized.parameters(
((5, 4, 24, 36), (1, 1), [2, 16, 36, 72], jnp.bfloat16),
((2, 4, 16, 8), (2, 2), [2, 16, 32, 128], jnp.bfloat16),
((4, 8, 16, 32), (1, 1), [2, 16, 32, 64], jnp.float32),
)
def test_spectral_norm_conv_fprop_dtype(
self, filter_shape, filter_stride, input_shape, fprop_dtype
):
inputs = np.random.normal(1.0, 0.5, input_shape).astype('float32')

p = pax_fiddle.Config(
convolutions.Conv2D,
name='jax_conv2d',
filter_shape=filter_shape,
filter_stride=filter_stride,
weight_norm_tpl=pax_fiddle.Config(normalizations.SpectralNorm),
fprop_dtype=fprop_dtype,
)
conv_layer = instantiate(p)
initial_vars = conv_layer.init(jax.random.PRNGKey(seed=123), inputs)
output = conv_layer.apply(initial_vars, inputs)
self.assertEqual(output.dtype, fprop_dtype)


if __name__ == '__main__':
absltest.main()

0 comments on commit d8346c4

Please sign in to comment.