Skip to content

Commit

Permalink
Add vqgan losses for video models.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668237503
  • Loading branch information
bignamehyp authored and pax authors committed Aug 28, 2024
1 parent cd7d76d commit 2080c2a
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 0 deletions.
35 changes: 35 additions & 0 deletions praxis/layers/video/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,41 @@ pytype_strict_test(
],
)

pytype_strict_library(
name = "losses",
srcs = [
"losses.py",
],
srcs_version = "PY3",
deps = [
# Implicit jax dependency.
"//praxis:base_layer",
"//praxis:base_model",
"//praxis:py_utils",
"//praxis:pytypes",
],
)

pytype_strict_test(
name = "losses_test",
srcs = [
"losses_test.py",
],
srcs_version = "PY3",
deps = [
":losses",
":vqvae",
# Implicit absl.testing.absltest dependency.
# Implicit absl.testing.parameterized dependency.
# Implicit jax dependency.
# Implicit numpy dependency.
"//praxis:base_layer",
"//praxis:pax_fiddle",
"//praxis:py_utils",
"//praxis:test_utils",
],
)

pytype_strict_library(
name = "quantizer",
srcs = [
Expand Down
1 change: 1 addition & 0 deletions praxis/layers/video/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
"""Exposes the public layer functionalities for video."""

from praxis.layers.video import enc_dec_3dcnn
from praxis.layers.video import losses
from praxis.layers.video import quantizer
from praxis.layers.video import vqvae
155 changes: 155 additions & 0 deletions praxis/layers/video/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loss functions for vqvae/vqgan models."""

from collections.abc import Callable
import jax
import jax.numpy as jnp
from praxis import base_layer
from praxis import base_model
from praxis import py_utils
from praxis import pytypes

JTensor = pytypes.JTensor


def r1_gradient_penalty(
inputs: JTensor,
logits_fn: Callable[[JTensor], JTensor],
grad_penalty_cost: float = 10.0,
) -> tuple[JTensor, JTensor]:
"""Calculates gradients penalty loss to regularize the discriminator.
Args:
inputs: A tensor of image inputs.
logits_fn: A function that takes inputs and returns logits.
grad_penalty_cost: scalar weight for the gradient penalty loss.
Returns:
A tuple of logits and the gradient penalty.
"""
out, vjp_fn = jax.vjp(logits_fn, inputs, has_aux=False)
# Check if jax.value_and_grad is more efficient than jax.vjp at scale.
grad = vjp_fn(jnp.ones_like(out))[0]
flattened_grad = jnp.asarray(grad.reshape((inputs.shape[0], -1)), jnp.float32)
penalty = (
jnp.mean(jnp.sum(jnp.square(flattened_grad), axis=-1)) * grad_penalty_cost
)
return out, penalty


def _discriminator_loss(logits_real: JTensor, logits_fake: JTensor) -> JTensor:
"""Calculates non-saturating discriminator loss."""
d_loss_real = jax.nn.softplus(-logits_real)
d_loss_fake = jax.nn.softplus(logits_fake)
return jnp.mean(d_loss_real) + jnp.mean(d_loss_fake)


def _generator_loss(logits_fake):
"""Calculates non-saturating generator loss."""
return jnp.mean(jax.nn.softplus(-logits_fake))


class VQGANLoss(base_layer.BaseLayer):
"""Loss layer for VQGAN."""

g_adversarial_loss_weight: float = 0.1
reconstruction_loss_weight: float = 5.0
polyak_decay: float = 0.999
lecam_weight: float = 0.001

def lecam_loss(self, real_pred: JTensor, fake_pred: JTensor) -> JTensor:
"""Calculates lecam loss.
Described in https://arxiv.org/abs/2104.03310
Args:
real_pred: scalar, predictions for the real samples.
fake_pred: scalar, prdictions for the reconstructed (fake) samples.
Returns:
Lecam regularization loss (scalar).
"""
ema_fake_pred = self.get_var('ema_fake_pred')
ema_real_pred = self.get_var('ema_real_pred')
return jnp.mean(
jnp.power(jax.nn.relu(real_pred - ema_fake_pred), 2)
) + jnp.mean(jnp.power(jax.nn.relu(ema_real_pred - fake_pred), 2))

def setup(self):
"""Constructs this jax module and registers variables."""
decay_factor_hparams = base_layer.WeightHParams(
shape=[],
init=base_layer.WeightInit.Constant(0.0),
dtype=jnp.float32,
collections=[base_layer.WeightHParamsCollection.REQUIRES_MEAN_SYNC],
)

self.create_variable('ema_real_pred', decay_factor_hparams, trainable=False)
self.create_variable('ema_fake_pred', decay_factor_hparams, trainable=False)

def __call__(
self, predictions: base_model.Predictions, input_batch: py_utils.NestedMap
) -> py_utils.NestedMap:
original_video = input_batch.video
reconstructed = predictions['reconstructed']
logits_real = predictions['logits_real']
logits_fake = predictions['logits_fake']
real_pred = jnp.mean(logits_real)
fake_pred = jnp.mean(logits_fake)

ema_fake_pred = self.get_var('ema_fake_pred')
ema_real_pred = self.get_var('ema_real_pred')
ema_fake_pred = (
fake_pred * (1 - self.polyak_decay) + ema_fake_pred * self.polyak_decay
)
ema_real_pred = (
real_pred * (1 - self.polyak_decay) + ema_real_pred * self.polyak_decay
)
self.update_var('ema_fake_pred', ema_fake_pred)
self.update_var('ema_real_pred', ema_real_pred)

losses = py_utils.NestedMap()
losses.grad_penalty = predictions['r1_gradient_penalty']
losses.lecam_loss = (
self.lecam_loss(logits_real, logits_fake) * self.lecam_weight
)

losses.d_adversarial_loss = _discriminator_loss(logits_real, logits_fake)
losses.g_adversarial_loss = (
_generator_loss(logits_fake) * self.g_adversarial_loss_weight
)

diff = jnp.asarray(original_video - reconstructed, jnp.float32)

losses.reconstruction_loss = (
jnp.mean(jnp.square(diff)) * self.reconstruction_loss_weight
)
losses.perceptual_loss = jnp.array(0.0, dtype=jnp.float32)
if self.do_eval:
losses.quantizer_loss = jnp.zeros_like(losses.reconstruction_loss)
else:
losses.quantizer_loss = predictions['quantizer_loss']
losses.d_loss = (
losses.d_adversarial_loss + losses.grad_penalty + losses.lecam_loss
)
losses.g_loss = (
losses.reconstruction_loss
+ losses.g_adversarial_loss
+ losses.perceptual_loss
+ losses.quantizer_loss
)
return losses
93 changes: 93 additions & 0 deletions praxis/layers/video/losses_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# coding=utf-8
# Copyright 2022 The Pax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

from absl.testing import absltest
from absl.testing import parameterized
import jax
import numpy as np
from praxis import base_layer
from praxis import pax_fiddle
from praxis import py_utils
from praxis import test_utils
from praxis.layers.video import losses
from praxis.layers.video import vqvae


class LossesTest(test_utils.TestCase):

def test_r1_gradient_penalty(self):
prng_key = jax.random.PRNGKey(seed=123)
x = jax.random.normal(prng_key, (2, 5, 16, 16, 3))
# Create a pax layer and get the output from the random input.
p = pax_fiddle.Config(
vqvae.Discriminator,
name='magvit',
num_frames=5,
image_height=16,
image_width=16,
filters=32,
channel_multipliers=(2, 4),
)
context_p = base_layer.JaxContext.HParams(do_eval=False)
with base_layer.JaxContext.new_context(hparams=context_p):
pax_layer = base_layer.instantiate(p)
pax_vars = pax_layer.init(prng_key, x)
logit_fn = functools.partial(pax_layer.apply, pax_vars)
logits, penalty = losses.r1_gradient_penalty(x, logit_fn)
self.assertEqual(logits.shape, (2, 1))
self.assertEqual(penalty.shape, ())

@parameterized.parameters(True, False)
def test_vqgan_loss(self, do_eval):
batch_size, num_frames, height, width, channels = 2, 5, 128, 128, 3
video_shape = (batch_size, num_frames, height, width, channels)
np.random.seed(12345)
input_batch = py_utils.NestedMap(
video=np.random.randint(0, 255, size=video_shape)
)
predictions = py_utils.NestedMap(
reconstructed=np.random.normal(size=video_shape),
logits_real=np.random.normal(size=(batch_size, 1)),
logits_fake=np.random.normal(size=(batch_size, 1)),
quantizer_loss=np.random.normal(size=[]),
r1_gradient_penalty=np.random.normal(size=[]),
)

loss_p = pax_fiddle.Config(
losses.VQGANLoss,
name='loss',
)
loss_layer = loss_p.Instantiate()
prng_key = jax.random.PRNGKey(seed=123)
context_p = base_layer.JaxContext.HParams(do_eval=do_eval)
with base_layer.JaxContext.new_context(hparams=context_p):
init_vars = loss_layer.init(prng_key, predictions, input_batch)
loss_dict, updated_vars = loss_layer.apply(
init_vars, predictions, input_batch, mutable=base_layer.NON_TRAINABLE
)
for loss in loss_dict.values():
self.assertEqual((), loss.shape)
self.assertNotEqual(
updated_vars[base_layer.NON_TRAINABLE]['ema_fake_pred'], 0.0
)
self.assertNotEqual(
updated_vars[base_layer.NON_TRAINABLE]['ema_real_pred'], 0.0
)


if __name__ == '__main__':
absltest.main()

0 comments on commit 2080c2a

Please sign in to comment.