# training.py
# define the Dice metric and loss, then build and compile the model
def dice_coef(y_true, y_pred, smooth=1):
    """Return the smoothed Dice coefficient averaged over axis 0.

    Per sample, the element-wise product and the two individual sums are
    reduced over axes 1, 2 and 3, giving
    ``(2 * overlap + smooth) / (sum(y_true) + sum(y_pred) + smooth)``.
    Those per-sample scores are then averaged over axis 0.

    Args:
        y_true: Ground-truth tensor with at least axes 1-3 (presumably
            ``(batch, height, width, channels)`` -- confirm against the
            dataset pipeline).
        y_pred: Predicted tensor, multiplied element-wise with ``y_true``.
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when both inputs sum to zero.

    Returns:
        The mean of the per-sample Dice scores.
    """
    reduce_axes = [1, 2, 3]
    overlap = K.sum(y_true * y_pred, axis=reduce_axes)
    # Sum of both masks (not a set union): the standard Dice denominator.
    total = K.sum(y_true, axis=reduce_axes) + K.sum(y_pred, axis=reduce_axes)
    per_sample = (2. * overlap + smooth) / (total + smooth)
    return K.mean(per_sample, axis=0)
def dice_loss(in_gt, in_pred):
    """Return the Dice loss: one minus ``dice_coef(in_gt, in_pred)``.

    Args:
        in_gt: Ground-truth tensor, forwarded to ``dice_coef`` as ``y_true``.
        in_pred: Predicted tensor, forwarded to ``dice_coef`` as ``y_pred``.
    """
    score = dice_coef(in_gt, in_pred)
    return 1 - score
# Build the network. The argument 1 is handed straight to unet_model
# (presumably the number of output channels -- confirm against its definition).
model = unet_model(1)

# Optimise the Dice loss with Adam; report the Dice score and binary accuracy.
model.compile(
    optimizer='adam',
    loss=dice_loss,
    metrics=[dice_coef, 'binary_accuracy'],
)

# Draw the layer graph, annotated with tensor shapes.
tf.keras.utils.plot_model(model, show_shapes=True)

# Print the layer-by-layer summary.
model.summary()
# visualization helper
def visualize(display_list):
    """Plot the given images side by side in one row and show the figure.

    Args:
        display_list: Sequence of image arrays ordered as input image,
            true mask, predicted mask. Panel titles come from that fixed
            order, so a fourth entry raises ``IndexError``.
    """
    plt.figure(figsize=(12, 12))
    panel_titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)
    for idx, image in enumerate(display_list):
        # Subplot positions are 1-based.
        plt.subplot(1, n_panels, idx + 1)
        plt.title(panel_titles[idx])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(image))
        plt.axis('off')
    plt.show()
# show some predictions before training
def show_predictions(sample_image, sample_mask):
    """Predict a mask for one image with the global ``model`` and plot it.

    The input image, the true mask and the predicted mask are displayed
    side by side via ``visualize``.

    Args:
        sample_image: A single image without a batch axis; a leading axis
            is added before calling ``model.predict``.
        sample_mask: Ground-truth mask shown next to the prediction.
    """
    batch_of_one = sample_image[tf.newaxis, ...]
    raw_prediction = model.predict(batch_of_one)
    # Collapse the prediction to (height, width, 1) using the global img_size.
    predicted = raw_prediction.reshape(img_size[0], img_size[1], 1)
    visualize([sample_image, sample_mask, predicted])
# Preview the untrained model on the first image of each of the first five
# training batches. Dataset.take(n) yields the first n elements in one pass,
# so it is called once here; calling take(i) per loop index would restart
# the dataset every time and replay the leading batches.
# sample_image / sample_mask are deliberately left at module level: the
# last pair is what DisplayCallback plots during training.
for images, masks in train_dataset.take(5):
    sample_image = images[0]
    sample_mask = masks[0]
    show_predictions(sample_image, sample_mask)
# train model
# Stop after 4 epochs without improvement of the monitored quantity (the
# Keras default, since none is given) and restore the best weights seen.
early_stop = tf.keras.callbacks.EarlyStopping(
    patience=4,
    restore_best_weights=True,
)
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that plots a prediction for the module-level sample
    (``sample_image`` / ``sample_mask``) at the start of every third epoch."""

    def on_epoch_begin(self, epoch, logs=None):
        """Plot predictions when ``epoch + 1`` is a multiple of three.

        Args:
            epoch: Epoch index passed by Keras (zero-based per the Keras
                callback documentation).
            logs: Unused.
        """
        if (epoch + 1) % 3 != 0:
            return
        show_predictions(sample_image, sample_mask)
EPOCHS = 40
# Number of whole batches in the training set; integer division drops any
# partial final batch.
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

# Fit on the training data, validating against valid_dataset; predictions
# are plotted periodically and training may end early via early_stop.
model_history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=valid_dataset,
    callbacks=[DisplayCallback(), early_stop],
)
# predict on test data
# Show the trained model's prediction for the first image of each of the
# first eight test batches. Dataset.take(n) yields the first n elements in
# one pass, so it is called once here; calling take(i) per loop index would
# restart the dataset every time and replay the leading batches.
for images, masks in test_dataset.take(8):
    tsample_image = images[0]
    tsample_mask = masks[0]
    show_predictions(tsample_image, tsample_mask)