Skip to content

Commit

Permalink
fix a bug in calculating clarity loss as sample weight is specified; …
Browse files Browse the repository at this point in the history
…update v0.5.6
  • Loading branch information
[zebinyang] committed Oct 19, 2021
1 parent 8d99bc9 commit 153bed4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
Binary file added dist/gaminet-0.5.6-py3-none-any.whl
Binary file not shown.
4 changes: 2 additions & 2 deletions gaminet/gaminet.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def call(self, inputs, sample_weight=None, main_effect_training=False, interacti
a2 = tf.multiply(tf.gather(self.maineffect_outputs, [k2], axis=1), tf.gather(main_weights, [k2], axis=0))
b = tf.multiply(tf.gather(self.interact_outputs, [i], axis=1), tf.gather(interaction_weights, [i], axis=0))
if sample_weight is not None:
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(tf.multiply(a1, b), sample_weight.reshape(-1, 1))))
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(tf.multiply(a2, b), sample_weight.reshape(-1, 1))))
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(tf.multiply(a1, b), tf.reshape(sample_weight, (-1, 1)))))
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(tf.multiply(a2, b), tf.reshape(sample_weight, (-1, 1)))))
else:
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(a1, b)))
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(a2, b)))
Expand Down

0 comments on commit 153bed4

Please sign in to comment.