Skip to content

Commit

Permalink
Stability and loss improvements for DPO
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 612828676
  • Loading branch information
The praxis Authors committed Mar 5, 2024
1 parent cad2a3a commit 3d015e5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 21 deletions.
27 changes: 8 additions & 19 deletions praxis/layers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,9 +1389,11 @@ def per_seq_log_p(softmax_out: NestedMap):
assert len(softmax_out.per_sequence_xent.shape) == 1
return -1.0 * softmax_out.per_sequence_xent

y_w_ref_log_p = per_seq_log_p(predictions.y_w_ref)
# Prevent backprop into reference model
y_w_ref_log_p = jax.lax.stop_gradient(per_seq_log_p(predictions.y_w_ref))
y_l_ref_log_p = jax.lax.stop_gradient(per_seq_log_p(predictions.y_l_ref))
# Allow backprop into policy model
y_w_pi_log_p = per_seq_log_p(predictions.y_w_pi)
y_l_ref_log_p = per_seq_log_p(predictions.y_l_ref)
y_l_pi_log_p = per_seq_log_p(predictions.y_l_pi)

self.add_summary('dpo/y_w_ref_log_p', jnp.mean(y_w_ref_log_p))
Expand All @@ -1403,30 +1405,19 @@ def per_seq_log_p(softmax_out: NestedMap):
add_hist(self, 'dpo/y_l_ref_log_p', y_l_ref_log_p)
add_hist(self, 'dpo/y_l_pi_log_p', y_l_pi_log_p)

r_hat_y_w = jax.lax.stop_gradient(
self.beta * (y_w_pi_log_p - y_w_ref_log_p)
)
r_hat_y_l = jax.lax.stop_gradient(
self.beta * (y_l_pi_log_p - y_l_ref_log_p)
)
r_hat_y_w = self.beta * (y_w_pi_log_p - y_w_ref_log_p)
r_hat_y_l = self.beta * (y_l_pi_log_p - y_l_ref_log_p)

# This is first equation on page #5 in the DPO paper.
per_example_loss = (
self.beta
* jax.nn.sigmoid(r_hat_y_l - r_hat_y_w)
* (y_l_pi_log_p - y_w_pi_log_p)
)
loss = jnp.mean(per_example_loss)
# This is the dpo_loss, same as what equation 7 in the paper computes.
dpo_loss = jnp.mean(-1.0 * jnp.log(jax.nn.sigmoid(r_hat_y_w - r_hat_y_l)))
loss = -1.0 * jnp.mean(jax.nn.log_sigmoid(r_hat_y_w - r_hat_y_l))

self.add_summary('dpo/r_hat_y_w', jnp.mean(r_hat_y_w))
self.add_summary('dpo/r_hat_y_w_std', jnp.std(r_hat_y_w))
self.add_summary('dpo/r_hat_y_l', jnp.mean(r_hat_y_l))
self.add_summary('dpo/r_hat_y_l_std', jnp.std(r_hat_y_l))
self.add_summary('dpo/delta_r_hat', jnp.mean(r_hat_y_w - r_hat_y_l))
self.add_summary('dpo/delta_r_hat_std', jnp.std(r_hat_y_w - r_hat_y_l))
self.add_summary('dpo/dpo_loss', dpo_loss)
self.add_summary('dpo/dpo_loss', loss)
self.add_summary(
'_dpo_topline/p_correct_ranking',
jnp.mean(jax.nn.sigmoid(r_hat_y_w - r_hat_y_l)),
Expand All @@ -1438,11 +1429,9 @@ def per_seq_log_p(softmax_out: NestedMap):
batch_size = predictions.y_l_ref.per_example_xent.shape[0]

# TODO(yonghui): Add diagnostic summaries.
# pair_loss is what learning back-props into.
return (
NestedMap(
total_loss=(loss, jnp.array(batch_size, loss.dtype)),
dpo_loss=(dpo_loss, jnp.array(batch_size, dpo_loss.dtype)),
),
{},
)
Expand Down
3 changes: 1 addition & 2 deletions praxis/layers/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,8 +1631,7 @@ def test_transformer_lm_dpo(self):
initial_vars = dpo_lm.init(prng_key, input_batch)
outputs, _ = dpo_lm.apply(initial_vars, input_batch)
logging.info('outputs: %s', outputs)
self.assertEqual(0.0, outputs.total_loss[0])
self.assertEqual(0.6931472, outputs.dpo_loss[0])
self.assertEqual(0.6931472, outputs.total_loss[0])


if __name__ == '__main__':
Expand Down

0 comments on commit 3d015e5

Please sign in to comment.