Skip to content

Commit

Permalink
fix a bug in calculating clarity loss as sample weight is specified; …
Browse files Browse the repository at this point in the history
…update v0.5.6
  • Loading branch information
[zebinyang] committed Oct 19, 2021
1 parent bca0acf commit 8d99bc9
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 4 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The following environments are required:

- Python 3.7 + (anaconda is preferable)
- tensorflow>=2.0.0
- tensorflow-lattice>=2.0.8
- numpy>=1.15.2
- pandas>=0.19.2
- matplotlib>=3.1.3
Expand All @@ -16,6 +17,8 @@ The following environments are required:
pip install gaminet
```

To use it on GPU, conda install tensorflow==2.2, pip install tensorflow-lattice==2.0.8, conda install tensorflow-estimators==2.2

## Usage

Import library
Expand Down
2 changes: 1 addition & 1 deletion gaminet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

__all__ = ["GAMINet"]

__version__ = '0.5.5'
__version__ = '0.5.6'
__author__ = 'Zebin Yang and Aijun Zhang'
4 changes: 2 additions & 2 deletions gaminet/gaminet.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def call(self, inputs, sample_weight=None, main_effect_training=False, interacti
a2 = tf.multiply(tf.gather(self.maineffect_outputs, [k2], axis=1), tf.gather(main_weights, [k2], axis=0))
b = tf.multiply(tf.gather(self.interact_outputs, [i], axis=1), tf.gather(interaction_weights, [i], axis=0))
if sample_weight is not None:
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(tf.multiply(a1, b), sample_weight)))
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(tf.multiply(a2, b), sample_weight)))
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(tf.multiply(a1, b), sample_weight.reshape(-1, 1))))
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(tf.multiply(a2, b), sample_weight.reshape(-1, 1))))
else:
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(a1, b)))
self.clarity_loss += tf.abs(tf.reduce_mean(tf.multiply(a2, b)))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
}

setup(name='gaminet',
version='0.5.5',
version='0.5.6',
description='Explainable Neural Networks based on Generalized Additive Models with Structured Interactions',
url='https://github.com/ZebinYang/GAMINet',
author='Zebin Yang',
Expand Down

0 comments on commit 8d99bc9

Please sign in to comment.