Skip to content

Commit

Permalink
Update weights and EMD loss
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed Jan 5, 2018
1 parent 20df435 commit 1177b4f
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
3 changes: 1 addition & 2 deletions evaluate_mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
model = Model(base_model.input, x)
model.load_weights('weights/mobilenet_weights.h5')

img_path = 'images/img.png'
img_path = 'images/art1.jpg'
img = load_img(img_path)
x = img_to_array(img)

x = np.expand_dims(x, axis=0)

x = preprocess_input(x)
Expand Down
15 changes: 13 additions & 2 deletions train_mobilenet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os

from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.optimizers import Adam
from keras import backend as K

from data_loader import train_generator, val_generator
Expand Down Expand Up @@ -47,7 +50,10 @@ def on_epoch_end(self, epoch, logs=None):
self.writer.flush()

def earth_mover_loss(y_true, y_pred):
return K.sqrt(K.mean(K.square(K.abs(K.cumsum(y_true, axis=-1) - K.cumsum(y_pred, axis=-1)))))
cdf_ytrue = K.cumsum(y_true, axis=-1)
cdf_ypred = K.cumsum(y_pred, axis=-1)
samplewise_emd = K.sqrt(K.mean(K.square(K.abs(cdf_ytrue - cdf_ypred)), axis=-1))
return K.mean(samplewise_emd)

image_size = 224

Expand All @@ -60,7 +66,12 @@ def earth_mover_loss(y_true, y_pred):

model = Model(base_model.input, x)
model.summary()
model.compile('adam', loss=earth_mover_loss)
optimizer = Adam(lr=1e-4)
model.compile(optimizer, loss=earth_mover_loss)

# load weights from trained model if it exists
if os.path.exists('weights/mobilenet_weights.h5'):
model.load_weights('weights/mobilenet_weights.h5')

checkpoint = ModelCheckpoint('weights/mobilenet_weights.h5', monitor='val_loss', verbose=1, save_weights_only=True, save_best_only=True,
mode='min')
Expand Down
Binary file modified weights/mobilenet_weights.h5
Binary file not shown.

0 comments on commit 1177b4f

Please sign in to comment.