-
Notifications
You must be signed in to change notification settings - Fork 5
/
seg_utils.py
1182 lines (910 loc) · 36.1 KB
/
seg_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
A collection of classes and functions for training fully-convolutional
CNNs for semantic segmentation. Includes a custom generator class for
use with the Keras model.fit_generator() method. It is specifically meant
for segmentation ground-truth labels, which require sample weights
to be used. (See https://github.com/keras-team/keras/issues/3653)
Author: Simon Thomas
Email: [email protected]
Start Date: 24/10/18
Last Update: 24/06/19
"""
import numpy as np
from sys import stderr
import h5py
import os
import io as IO
import skimage.io as io
from skimage.measure import regionprops
from skimage.morphology import disk
from skimage.transform import rotate
from skimage.filters import median
from sklearn.linear_model import LinearRegression
from cv2 import resize, cvtColor, COLOR_RGB2BGR, imwrite
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.colors as cols
from matplotlib.pyplot import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from keras.callbacks import Callback, TensorBoard
from keras.models import Model
import keras.backend as K
import tensorflow as tf
from sklearn.utils.class_weight import compute_class_weight
from pandas_ml import ConfusionMatrix
class Palette(object):
    """
    A colour palette: a per-channel colour look-up table mapping each
    class index to the colour used for that class in the ground truth.
    """

    def __init__(self, ordered_list):
        """
        Store the colours keyed by their position in ``ordered_list``.

        Input:
            ordered_list - iterable of rgb tuples in class order
        Output:
            self[index] - rgb tuple associated with index/class
        """
        self.colors = {}
        for index, rgb in enumerate(ordered_list):
            self.colors[index] = rgb

    def __getitem__(self, arg):
        """
        Return the colour stored for class/channel ``arg`` (KeyError if absent).
        """
        lookup = self.colors
        return lookup[arg]

    def __str__(self):
        """
        Human-readable dump of the index -> colour mapping.
        """
        header = "Channel Color Palette:\n"
        return header + str(self.colors)

    def __repr__(self):
        return self.__str__()

    def __len__(self):
        """
        Number of colours (classes) in the palette.
        """
        return len(self.colors)
class SegmentationGen(object):
    """
    A generator that returns (X, y, sample_weight) batches of a fixed size,
    specifically for segmentation problems. Colour-coded ground-truth images
    are converted to one-hot arrays, and per-pixel sample weights are
    calculated on the fly for each batch ("balanced" class weighting).
    y and the sample weights are flattened to (batch_size, y_dim*y_dim, ...),
    so predictions need to be reshaped in order to be interpreted/visualised.
    Validation needs to be done separately due to implementation differences
    with keras, but remains fairly straightforward.
    Training is very sensitive to batch size and learning rate.
    See again: https://github.com/keras-team/keras/issues/3653

    Example:
        >>> colours = [(0,0,255),(0,255,0),(255,0,0)]
        >>> palette = Palette(colours)
        >>> batch_size = 10
        >>> dim = 128
        >>> train_gen = SegmentationGen(batch_size, X_dir, y_dir, palette,
        ...                             x_dim=dim, y_dim=dim)
        >>> # Compile model - include sample_weight_mode="temporal"
        ... model.compile(optimizer=SGD(lr=0.001),
        ...               loss="categorical_crossentropy",
        ...               sample_weight_mode="temporal",
        ...               metrics=["accuracy"])
        >>> # Train
        ... history = model.fit_generator(
        ...     generator=train_gen,
        ...     steps_per_epoch=train_gen.n // batch_size,
        ...     validation_data=val_gen,
        ...     validation_steps=val_gen.n // batch_size)
        >>> # Evaluate
        ... loss, acc = model.evaluate_generator(
        ...     generator=test_gen,
        ...     steps=2)

    Input:
        batch_size - number of images in a batch
        X_dir      - full path directory of training images; every entry
                     listed here is treated as one sample
        y_dir      - full path directory of training labels; for an image
                     "<stem><ext>" the label "<stem>.png" must exist here
        palette    - colour palette object where each index (range(n-classes))
                     is the class colour in the segmented ground truth. Can be
                     obtained from the LUT of some standard segmentation datasets.
        x_dim      - side length images are resized to (x_dim x x_dim)
        y_dim      - side length labels are resized to (y_dim x y_dim)
        suffix     - extension of the raw images in X_dir. Default is ".png"
        weight_mod - a dictionary to modify certain weights by index
                     i.e. weight_mod = {0 : 1.02} increases weight 0 by 2%.
                     Default is None.
    Output:
        next(gen) returns a tuple (X_train, y_train, sample_weights):
            X_train.shape        = (batch_size, x_dim, x_dim, 3), float32,
                                   mapped from [0, 255] to [-1, 1]
            y_train.shape        = (batch_size, y_dim*y_dim, num_classes), one-hot
            sample_weights.shape = (batch_size, y_dim*y_dim)
    """

    def __init__(self,
                 batch_size, X_dir, y_dir,
                 palette, x_dim, y_dim,
                 suffix=".png", weight_mod=None):
        self.batch_size = batch_size
        self.X_dir = X_dir
        self.y_dir = y_dir
        self.palette = palette
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.suffix = suffix
        self.weight_mod = weight_mod
        self.files = np.array(os.listdir(X_dir))
        self.num_classes = len(palette)
        self.n = len(self.files)
        self.cur = 0
        self.order = list(range(self.n))
        np.random.shuffle(self.order)
        # Names of the files used in the most recent batch
        self.files_in_batch = None

    def __iter__(self):
        """
        The generator is its own iterator, so it works with ``iter()``/``for``.
        """
        return self

    # Helper functions
    def _getClassMask(self, rgb, image):
        """
        Takes a class value and returns a boolean mask of size
        image.shape[0] x image.shape[1] indicating where it is present.
        Input:
            rgb   - (r, g, b) tuple for colour labels, or a single value
                    for 8-bit (single-channel) labels
            image - segmentation ground truth image
        Output:
            mask - boolean mask
        """
        # Colour mask: all three channels must match
        if np.ndim(rgb) > 0 and len(rgb) == 3:
            return np.all(image == rgb, axis=-1)
        # 8-bit mask (np.ndim guard: len() would raise on a scalar value)
        return image[:, :] == rgb

    def _calculateWeights(self, y_train):
        """
        Calculates the balanced weights of all the classes in the batch.
        Uses the same formula as scikit-learn's "balanced" mode,
        n_pixels / (n_present_classes * count_c), computed directly from the
        per-class counts instead of materialising one label per pixel.
        Input:
            y_train - one-hot ground truth with classes on the last axis,
                      e.g. (batch, dim, dim, num_classes)
        Output:
            weights - array with one weight per class; 0 for classes absent
                      from the batch
        """
        # Pixel count of every class over the whole batch; rint guards
        # against float summation noise (one-hot sums are whole numbers).
        counts = np.rint(
            y_train.reshape(-1, self.num_classes).sum(axis=0, dtype=np.float64))
        present = counts > 0
        weights = np.zeros(self.num_classes, dtype=np.float64)
        if present.any():
            weights[present] = counts.sum() / (present.sum() * counts[present])
        # Modify weight for a particular class
        if self.weight_mod:
            for key, factor in self.weight_mod.items():
                weights[key] *= factor
        return weights

    def _createBatches(self, positions):
        """
        Creates X_train, y_train and sample-weight batches from the given
        positions, i.e. indices into self.files.
        Input:
            positions - list of integers representing files
        Output:
            X_train, y_train, sample_weights - batches (see class docstring)
        """
        # Nearest-neighbour keeps label colours exact; interpolating would
        # blend class colours into values that match no palette entry.
        from cv2 import INTER_NEAREST

        X_batch = []
        y_batch = []
        # Save file names for batch
        self.files_in_batch = self.files[positions]
        for pos in positions:
            # File name without its extension (works for any extension length)
            stem = os.path.splitext(self.files[pos])[0]
            # load X-image
            im = io.imread(os.path.join(self.X_dir, stem + self.suffix))[:, :, 0:3]  # drop alpha
            im = resize(im, (self.x_dim, self.x_dim))
            X_batch.append(im)
            # Load y-image
            im = io.imread(os.path.join(self.y_dir, stem + ".png"))[:, :, 0:3]  # drop alpha
            im = resize(im, (self.y_dim, self.y_dim), interpolation=INTER_NEAREST)
            # Convert to one-hot ground truth
            y = np.zeros((im.shape[0], im.shape[1], self.num_classes), dtype=np.float32)
            for i in range(self.num_classes):
                y[self._getClassMask(self.palette[i], im), i] = 1.
            y_batch.append(y)
        # Combine images into batches and normalise
        X_train = np.stack(X_batch, axis=0).astype(np.float32)
        y_train = np.stack(y_batch, axis=0)
        # Preprocess X_train: [0, 255] -> [-1, 1]
        X_train /= 255.
        X_train -= 0.5
        X_train *= 2.
        # Calculate sample weights and take the weight of each pixel's class
        weights = self._calculateWeights(y_train)
        sample_weights = np.take(weights, np.argmax(y_train, axis=-1))
        # Reshape to suit keras (sample_weight_mode="temporal")
        sample_weights = sample_weights.reshape(y_train.shape[0], self.y_dim * self.y_dim)
        y_train = y_train.reshape(y_train.shape[0],
                                  self.y_dim * self.y_dim,
                                  self.num_classes)
        return X_train, y_train, sample_weights

    def __next__(self):
        """
        Returns the next full batch when ``next()`` is called on the generator.
        When fewer than batch_size unused files remain, the order is
        reshuffled and a new pass starts, so every batch is full-sized
        (useful on multi-GPUs).
        Raises:
            ValueError - if X_dir holds fewer than batch_size files; the
                         previous implementation looped forever in this case.
        """
        if self.n < self.batch_size:
            raise ValueError(
                "SegmentationGen needs at least batch_size ({}) files in X_dir, "
                "found {}".format(self.batch_size, self.n))
        # Not enough files left for a full batch: start a new shuffled pass
        if self.cur + self.batch_size > self.n:
            np.random.shuffle(self.order)
            self.cur = 0
        positions = self.order[self.cur:self.cur + self.batch_size]
        self.cur += self.batch_size
        return self._createBatches(positions)

    # Python 2 style alias, matching the "internal next()" usage
    next = __next__
def predict_image(model, image):
    """
    Simplifies image prediction for segmentation models. Automatically
    reshapes the flat model output so it can be visualised.
    Input:
        model - model whose predict() returns (1, height*width, num_classes)
                values and where model.layers[-1].output_shape[-1] is the
                number of classes (e.g. a ResNet model ending in a reshape layer)
        image - pre-processed image of shape (height, width, 3), or an
                already batched (1, height, width, 3) array; see load_image()
    Output:
        preds     - probability heatmap of shape (height, width, num_classes)
        class_img - argmax of preds, shape (height, width)
    """
    if len(image.shape) < 4:
        # Add new axis to conform to model input
        image = image[np.newaxis, ...]
    # Spatial size comes from axes 1 and 2; axis 0 is the batch axis.
    height, width = image.shape[1], image.shape[2]
    num_classes = model.layers[-1].output_shape[-1]
    preds = model.predict(image)[0].reshape(height, width, num_classes)
    class_img = np.argmax(preds, axis=-1)
    return preds, class_img
def get_color_map(colors):
    """
    Build a matplotlib ListedColormap from 0-255 RGB values.
    Input:
        colors - a list of RGB colors with components in [0, 255]
    Output:
        cmap - a matplotlib ListedColormap with components scaled to [0, 1]
    """
    scaled_colors = [[channel / 255. for channel in rgb] for rgb in colors]
    return cols.ListedColormap(scaled_colors)
def apply_color_map(colors, image):
    """
    Paints a class-index image with the colour of each class.
    Indexes a lookup table directly instead of normalising to [0, 1] and
    letting matplotlib multiply back by N and truncate, which can land one
    index low for some class counts (e.g. 1/49*49 < 1).
    Input:
        colors - list of colors (0-255 components) in class order; only the
                 first three components are used
        image  - class-index image with shape (n, n); values below 0 get the
                 first colour and values >= len(colors) the last (the same
                 clipping the matplotlib colormap applied)
    Output:
        color_image - float image with shape (n, n, 3) and values in [0, 1]
    """
    lut = np.asarray(colors, dtype=np.float64)[:, 0:3] / 255.
    indices = np.clip(np.asarray(image), 0, len(colors) - 1).astype(int)
    return lut[indices]
def load_image(fname, pre=True):
    """
    Load an image from disk as a float32 RGB array, optionally applying
    the ResNet50 style pre-processing used throughout this module.
    (No resizing is performed.)
    Input:
        fname - path + name of file to load
        pre   - if True, map pixel values v to (v / 255 - 0.5) * 2,
                i.e. [0, 255] -> [-1, 1]
    Output:
        im - float32 image as numpy array; any alpha channel is dropped
    """
    image = io.imread(fname).astype("float32")[:, :, 0:3]
    if pre:
        image = (image / 255. - 0.5) * 2.
    return image
def set_weights_for_training(model, fine_tune, layer_num=(81, 174)):
    """
    Sets which layers of ``model`` are trainable for a training stage.
    Despite the name, no weight values are loaded: the function toggles
    ``layer.trainable`` and zeroes the ``moving_mean``/``moving_variance``
    of layers that expose both.
    Input:
        model     - ResNet_UNet model by default, can be any model with a
                    ``layers`` list
        fine_tune - False ("base model"): layers[:layer_num[1]] are frozen,
                    except layers with moving statistics, which are made
                    trainable and have those statistics zeroed;
                    layers[layer_num[1]:] are trainable.
                    True ("fine tuning"): layers[layer_num[0]:layer_num[1]]
                    are made trainable (moving statistics zeroed) and
                    layers[layer_num[1]:] are trainable; layers before
                    layer_num[0] are left unchanged.
        layer_num - pair (fine_tune_start, head_start). Default (81, 174):
                    81 is res4a_branch2a and 174 is up_sampling2d_1
                    (173 is add_16) in the ResNet_UNet model. A tuple is
                    used as the default so it cannot be mutated between calls.
    Output:
        None
    """
    fine_tune_start, head_start = layer_num[0], layer_num[1]

    def _has_moving_stats(layer):
        # Layers exposing running mean/variance statistics
        return hasattr(layer, 'moving_mean') and hasattr(layer, 'moving_variance')

    def _zero_moving_stats(layer):
        # NOTE(review): zeroing the moving *variance* is unusual (resets to
        # ones are more common); kept as in the original - confirm intent.
        K.eval(K.update(layer.moving_mean, K.zeros_like(layer.moving_mean)))
        K.eval(K.update(layer.moving_variance, K.zeros_like(layer.moving_variance)))

    if not fine_tune:
        print("[INFO] base model...")
        # ResNet layers
        for layer in model.layers[0:head_start]:
            if _has_moving_stats(layer):
                # Opens up mean and variance for training
                layer.trainable = True
                _zero_moving_stats(layer)
            else:
                layer.trainable = False
    else:
        print("[INFO] fine tuning model...")
        # ResNet layers
        for layer in model.layers[fine_tune_start:head_start]:
            layer.trainable = True
            if _has_moving_stats(layer):
                _zero_moving_stats(layer)
    # UNet layers are trainable in both stages
    for layer in model.layers[head_start:]:
        layer.trainable = True
def get_number_of_images(dir):
    """
    Count the regular files (sub-directories are ignored) directly inside
    a directory.
    Input:
        dir - full path of directory
    Output:
        number of files in directory
    """
    count = 0
    for entry in os.listdir(dir):
        if os.path.isfile(os.path.join(dir, entry)):
            count += 1
    return count
def _order_checkpoint_weights(layer, datasets):
    """
    Return the checkpoint arrays for ``layer`` in the order of
    ``layer.weights`` (the order set_weights expects).
    Each weight is matched by the last component of its name, e.g.
    "bn_conv1/gamma:0" -> "gamma:0". If any weight cannot be matched, fall
    back to reversing the stored (alphabetical) order. That fallback is right
    for e.g. kernel/bias pairs but NOT for BatchNorm (beta, gamma,
    moving_mean, moving_variance), which is why name matching is tried first.
    Input:
        layer    - keras layer (uses ``layer.weights`` and each ``.name``)
        datasets - mapping dataset name -> array-like, in file order
    Output:
        list of numpy arrays
    """
    try:
        ordered = [datasets[w.name.split("/")[-1]] for w in layer.weights]
    except (AttributeError, KeyError):
        ordered = None
    if ordered is None or len(ordered) != len(datasets):
        ordered = list(reversed(list(datasets.values())))
    return [np.array(values) for values in ordered]


def load_multigpu_checkpoint_weights(model, h5py_file):
    """
    Loads the weights of a weight checkpoint from a multi-gpu
    keras model into ``model``.
    Weights are read from the "model_1" group of the file, one subgroup per
    layer name. Layers with no stored group are skipped silently; layers
    whose weights cannot be set are reported on stderr with the reason.
    Input:
        model     - keras model to load weights into
        h5py_file - path to the h5py weights file
    Output:
        None
    """
    print("Setting weights...")
    with h5py.File(h5py_file, "r") as file:
        # Get model subset in file - other layers are empty
        weight_file = file["model_1"]
        for layer in model.layers:
            try:
                layer_weights = weight_file[layer.name]
            except KeyError:
                # No weights saved for layer
                continue
            try:
                datasets = {key: layer_weights[key] for key in layer_weights
                            if isinstance(layer_weights[key], h5py.Dataset)}
                layer.set_weights(_order_checkpoint_weights(layer, datasets))
            except Exception as e:
                print("Error: Could not load weights for layer:", layer.name, e, file=stderr)
def create_prob_map_from_mask(filename, palette):
    """
    Creates a one-hot "probability" map from a colour-coded mask file.
    Input:
        filename - path to mask file
        palette  - colour palette of the mask (indexable, with len()): an
                   (r, g, b) tuple per class for colour masks, or a single
                   value per class for 8-bit (single-channel) masks
    Output:
        prob_map - float32 numpy array of size image.h x image.w x num_classes,
                   1. where a pixel matches palette[i] in channel i, else 0.
    """
    # Helper functions
    def _get_class_mask(rgb, im):
        """
        Takes a class value and returns a boolean mask of size
        im.shape[0] x im.shape[1] indicating where it is present.
        Input:
            rgb - tuple of (r, g, b), or a single 8-bit value
            im  - segmentation ground truth image
        Output:
            mask - boolean mask
        """
        # Scalar class value: 8-bit mask (len() would raise TypeError here)
        if np.ndim(rgb) == 0:
            return im == rgb
        # Colour mask: first three channels must all match (alpha ignored)
        if len(rgb) == 3:
            return np.all(im[:, :, 0:3] == np.asarray(rgb), axis=-1)
        # 8-bit mask given as a sequence, e.g. (5,)
        return im[:, :] == rgb

    # -------------------------#
    num_classes = len(palette)
    # Load y-image
    im = io.imread(filename)
    # Convert to one-hot ground truth
    prob_map = np.zeros((im.shape[0], im.shape[1], num_classes), dtype=np.float32)
    # Loop through colors in palette and assign to new array
    for i in range(num_classes):
        prob_map[_get_class_mask(palette[i], im), i] = 1.
    return prob_map
# def generate_ROC_AUC(true_map, prob_map, color_dict, colors):
# """
# Generates ROC curves and AUC values for all class in image, as well
# as keeps raw data for later use.
#
# Input:
# true_map - map of true values, generated from mask using
# create_prob_map_from_mask()
# prob map - 3 dimensional prob_map created from model.predict()
#
# color_dict - color dictionary containing names and colors
#
# colors - list of colors
# Output:
#
# ROC - dictionary:
# "AUC" - scalar AUC value
# "TPR" - array of tpr for different thresholds
# "FPR" - array of fpr for different thresholds
# "raw_data" - tuple of (true, pred) where each are arrays
#
# ! NEED TO INCLUDE SAMPLE WEIGHTS !
#
# """
# # Create ROC curves for all tissue types
# ROC = {}
# for tissue_class in color_dict.keys():
# # Get class index
# class_idx = colors.index(color_dict[tissue_class])
#
# true = np.ravel(true_map[:, :, class_idx])
# pred = np.ravel(prob_map[:, :, class_idx])
#
# # Get FPR and TPR
# fpr, tpr, thresholds = roc_curve(true, pred)
# roc_auc = auc(fpr, tpr)
# if np.isnan(roc_auc):
# # class not present
# continue
# # Update values
# ROC[tissue_class] = {"AUC": roc_auc, "TPR": tpr, "FPR": fpr, "raw_data": (true, pred)}
#
# return ROC
def calculate_tile_size(image_shape, lower=50, upper=150):
    """
    Calculates a square tile size with a suitable overlap for an image.
    Candidate tile sizes are the multiples of 32 from 192 to 1408. For each
    size d that fits at least once in both directions, the overlap along an
    axis of length L is (d - L % d) // (L // d), and the tile's threshold is
    half the smaller of the two overlaps.
    Input:
        image_shape - shape of the original histo image (large size),
                      (height, width, ...)
        lower       - exclusive lower bound for the threshold
        upper       - exclusive upper bound for the threshold
    Output:
        w, h      - image width and height
        dim       - dimension of the tile
        threshold - calculated overlap threshold for tile and input image
    Candidates are examined from the smallest threshold up. If none lies in
    (lower, upper), the candidate with the smallest threshold above 10 is
    used instead; if no candidate exceeds 10, the largest threshold is used.
    Raises:
        ValueError - if the image is smaller than the smallest tile size
    """
    h, w = image_shape[0], image_shape[1]
    thresholds = {}
    for d in (x * (2 ** 5) for x in range(6, 45)):
        w_steps = w // d
        h_steps = h // d
        if w_steps == 0 or h_steps == 0:
            continue
        w_overlap = (d - (w % d)) // w_steps
        h_overlap = (d - (h % d)) // h_steps
        # Threshold is half the minimum overlap
        thresholds[d] = min(w_overlap, h_overlap) // 2
    if not thresholds:
        raise ValueError(
            "Image of shape {} is smaller than the smallest tile size".format(image_shape))
    # Take the first (smallest-threshold) pair that satisfies the bounds
    sorted_thresholds = sorted(thresholds.items(), key=lambda item: item[1])
    for d, t in sorted_thresholds:
        if lower < t < upper:
            return w, h, d, t  # dim, threshold
    print("[INFO] - title overlap threshold not met. Defaulting to Smallest non-zero overlap")
    # Fallback returns the same 4-tuple shape (the old code returned a
    # 3-tuple that callers could not unpack).
    for d, t in sorted_thresholds:
        if t > 10:
            return w, h, d, t
    d, t = sorted_thresholds[-1]
    return w, h, d, t
def whole_image_predict(files, model, output_directory, colors, compare=True, pad_val=50, prob_map=False):
    """
    Generates a segmentation for each of the images in files, tile by tile,
    and saves the results in the output directory.
    Each image is loaded with load_image(pre=True) and padded by pad_val
    pixels of the constant 0.99 on each spatial side. It is then split into
    overlapping square tiles whose size and border come from
    calculate_tile_size(). Only the centre of each tile prediction (the tile
    minus a threshold-pixel border) is written to the output canvas.
    Input:
        files            - list of files including their full paths
        model            - model to use to predict.
                           >> Must be of type K.function NOT keras.models.Model
                           >> Must have output shape (1, dim, dim, len(colors))
        output_directory - path prefix to save files to; include / on end
        colors           - list of RGB values, one per class, used to colour
                           the segmentation
        compare          - if True, save a side-by-side figure of the ground
                           truth ("../Masks/<name>.png" relative to the image's
                           directory) and the prediction as <name>.png
        pad_val          - padding in pixels added to each side before tiling
        prob_map         - if True (and compare is False), save the per-class
                           probabilities as <name>.npy instead of a colour image
    Output:
        None - results are written to disk. Images that fail to load or tile
        are reported on stderr and skipped.
    """
    for image_num, path in enumerate(files, start=1):
        # Base name without directory or last extension,
        # e.g. "/data/Images/slide.01.png" -> "slide.01"
        name = os.path.splitext(os.path.basename(path))[0]
        print("Whole Image Segmentation:", name, "Num:", image_num, "of", len(files))
        try:
            histo = load_image(path, pre=True)
            # Pad image with a constant close to the top of the [-1, 1] range
            # https://stackoverflow.com/questions/35751306/python-how-to-pad-numpy-array-with-zeros
            histo = np.pad(histo, [(pad_val, pad_val), (pad_val, pad_val), (0, 0)],
                           mode="constant", constant_values=0.99)
            # Create canvas to add predictions
            if prob_map:
                canvas = np.zeros((histo.shape[0], histo.shape[1], len(colors)))
            else:
                canvas = np.zeros_like(histo)
            # Tile info
            w, h, dim, threshold = calculate_tile_size(histo.shape, lower=50, upper=100)
            print("Tile size:", dim)
            print("Tile threshold:", threshold)
        except Exception as e:
            print("Failed to process:", name, e, file=stderr)
            continue

        # Number of steps and overlap in each direction
        w_steps = w // dim
        w_overlap = (dim - (w % dim)) // w_steps
        h_steps = h // dim
        h_overlap = (dim - (h % dim)) // h_steps
        num_tiles = (h_steps + 1) * (w_steps + 1)

        step = 1
        for i in range(h_steps + 1):
            h_x = i * (dim - h_overlap)
            h_y = h_x + dim
            for j in range(w_steps + 1):
                w_x = j * (dim - w_overlap)
                w_y = w_x + dim
                print("Processing tile", step, "of", num_tiles)
                step += 1
                # Grab a tile: shape (1, rows, cols, 3)
                tile = histo[h_x:h_y, w_x:w_y, :][np.newaxis, ...]
                # Tiles clipped by the image border are stretched back to (dim, dim)
                if tile.shape[1:3] != (dim, dim):
                    tile = resize(tile[0], dsize=(dim, dim))[np.newaxis, ...]
                # Predict
                probs = model([tile])[0]
                if prob_map:
                    prediction = probs[0]
                else:
                    prediction = apply_color_map(colors, np.argmax(probs[0], axis=-1))
                # Paste the tile centre (threshold border removed), clipped to
                # the canvas. Explicit end indices keep this correct when
                # threshold == 0, where a [t:-t] slice would be empty.
                top, bottom = h_x + threshold, min(h_y - threshold, canvas.shape[0])
                left, right = w_x + threshold, min(w_y - threshold, canvas.shape[1])
                canvas[top:bottom, left:right, :] = prediction[
                    threshold:threshold + (bottom - top),
                    threshold:threshold + (right - left), :]

        # Crop canvas by removing padding (explicit ends keep pad_val == 0 working)
        canvas = canvas[pad_val:canvas.shape[0] - pad_val,
                        pad_val:canvas.shape[1] - pad_val, :]
        fname = output_directory + name + ".png"
        if compare:
            # Ground truth lives in a "Masks" directory next to the image directory
            mask_path = "/".join(path.split("/")[0:-2]) + "/Masks/" + name + ".png"
            mask = io.imread(mask_path)
            fig, axes = plt.subplots(1, 2, figsize=(12, 8), frameon=False)
            axes[0].imshow(mask)
            axes[0].set_title("Ground Truth")
            axes[0].set_axis_off()
            axes[1].imshow(canvas)
            axes[1].set_title("Predicted")
            axes[1].set_axis_off()
            plt.tight_layout()
            plt.savefig(fname, dpi=300)
            plt.close(fig)
        elif prob_map:
            # splitext only strips the extension; split(".") broke on
            # directories containing a dot such as "./results/"
            np.save(os.path.splitext(fname)[0] + ".npy", canvas)
        else:
            io.imsave(fname, canvas)
        # Wipe canvas from memory before the next image
        del canvas
class Validation(TensorBoard):
"""
A custom callback to perform validation at the
end of each epoch. Also writes useful class
metrics to tensorboard logs.
"""
def __init__(self, generator, steps, classes, run_name,
             color_list, WSI=False, model_to_save=None,
             weight_path=None, interval=5,
             **kwargs):
    """
    Set up the validation callback.
    Input:
        generator     - validation generator of type SegmentationGen()
        steps         - number of validation steps, e.g. n // batch_size
        classes       - ordered list of class names, e.g. ["EPI", "GLD", ...]
        run_name      - str, the unique run identifier
        color_list    - list of RGB values for colouring predictions
        WSI           - flag stored as self.WSI (default False)
        model_to_save - model whose weights are checkpointed (useful for
                        multi-GPU models; pass the non-parallel model)
        weight_path   - stored as self.weight_path (default None)
        interval      - stored as self.interval (default 5)
        **kwargs      - forwarded to TensorBoard, e.g. log_dir
    """
    # Parent initialiser runs before any of this callback's attributes are set
    super().__init__(**kwargs)
    # What to validate on
    self.validation_data, self.validation_steps = generator, steps
    self.classes = np.asarray(classes)
    # Confusion matrices collected so far (write_current_cm reads the last one)
    self.cms = []
    # Run identification and display
    self.run_name, self.color_list, self.WSI = run_name, color_list, WSI
    # Checkpointing settings
    self.model_to_save, self.weight_path, self.interval = model_to_save, weight_path, interval
# Helper functions ------------------------------------------------ #
def write_confusion_matrix_to_buffer(self, matrix, classes):
    """
    Plots a row-normalised confusion matrix (recall per ground-truth class)
    and writes it to an in-memory PNG buffer.
    Input:
        matrix  - square numpy confusion matrix; rows are ground truth,
                  columns are predictions
        classes - ordered list of class names (must be a list: it is
                  concatenated with [''] for the tick labels)
    Output:
        buffer - io.BytesIO positioned at 0 holding the PNG plot
    """
    matrix = np.asarray(matrix, dtype=np.float64)
    # Compute row sums for Recall. Rows of classes absent from the ground
    # truth sum to 0; report 0 there instead of NaN from 0 / 0.
    row_sums = matrix.sum(axis=1, keepdims=True)
    recall = np.divide(matrix, row_sums, out=np.zeros_like(matrix), where=row_sums > 0)
    matrix = np.round(recall, 3)
    # Plot colours: white -> orange
    color = [255, 118, 25]
    orange = [c / 255. for c in color]
    white_orange = LinearSegmentedColormap.from_list("", ["white", orange])
    fig = plt.figure(figsize=(12, 14))
    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix, interpolation='nearest', cmap=white_orange)
    fig.colorbar(cax)
    ax.set_xticklabels([''] + classes, fontsize=8)
    ax.set_yticklabels([''] + classes, fontsize=8)
    # One tick per class so every label is shown
    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(1))
    ax.set_title("Recall")
    ax.set_ylabel("Ground Truth")
    ax.set_xlabel("Predicted")
    for i in range(len(classes)):
        for j in range(len(classes)):
            ax.text(j - 0.1, i, str(matrix[i, j]), fontsize=8)
    buffer = IO.BytesIO()
    plt.savefig(buffer, format="png")
    buffer.seek(0)
    plt.close(fig)
    return buffer
def write_current_predict(self, mask, prediction, image_num, epoch):
    """
    Writes a side-by-side plot of a ground-truth mask and a prediction to
    TensorBoard under the tag "Segmentation/Segmentation_E_<epoch>".
    Input:
        mask       - class-index ground truth image (coloured with self.color_list)
        prediction - class-index predicted image
        image_num  - currently unused
        epoch      - epoch number, used in the tag and as the summary step
    """
    fig, axes = plt.subplots(1, 2)
    panels = ((axes[0], mask, "Ground Truth"), (axes[1], prediction, "Predict"))
    for ax, image, title in panels:
        ax.imshow(apply_color_map(self.color_list, image))
        ax.set_title(title)
        # Per-axes call: plt.axis('off') only affected the current (last)
        # axes, so the ground-truth panel kept its axes.
        ax.set_axis_off()
    # save to buffer
    plot_buffer = IO.BytesIO()
    plt.savefig(plot_buffer, format="png")
    plot_buffer.seek(0)
    plt.close(fig)
    # Create an Image object
    # NOTE(review): height/width describe the mask, not the rendered figure
    img_sum = tf.Summary.Image(encoded_image_string=plot_buffer.getvalue(),
                               height=mask.shape[0],
                               width=mask.shape[1])
    # Create a Summary value
    im_summary = tf.Summary.Value(image=img_sum, tag="Segmentation/Segmentation_E_" + str(epoch))
    summary = tf.Summary(value=[im_summary])
    # The summary step is an integer global step
    self.writer.add_summary(summary, int(epoch))
def write_current_cm(self, epoch):
    """
    Render the most recent confusion matrix (self.cms[-1]) and log it to
    TensorBoard under the tag "Confusion_Matrix".
    """
    png_bytes = self.write_confusion_matrix_to_buffer(
        self.cms[-1], list(self.classes)).getvalue()
    # NOTE(review): 800x800 is hard-coded, while the figure is drawn at
    # figsize (12, 14); confirm these metadata values are intended.
    image = tf.Summary.Image(height=800, width=800, encoded_image_string=png_bytes)
    value = tf.Summary.Value(tag="Confusion_Matrix", image=image)
    self.writer.add_summary(tf.Summary(value=[value]), str(epoch))
def compute_stats(self, epoch_cm, logs):
"""