From 99b61bc6a9cf810ab06b8a442054ea3f0d3b2203 Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 21 Aug 2024 12:08:56 -0700 Subject: [PATCH] add more postprocessing discussion --- solution.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/solution.py b/solution.py index 56e5358..3b49bf9 100644 --- a/solution.py +++ b/solution.py @@ -1000,11 +1000,12 @@ def create_aff_target(self, mask): # Training will break if you change the number of affinities. It is a simple fix, you will just need to change the number # of output channels the unet produces. +neighborhood = [[0, 1], [1, 0], [0, 5], [5, 0]] train_data = AffinityDataset( "tissuenet_data/train", v2.RandomCrop(256), weights=True, - neighborhood=[[0, 1], [1, 0], [0, 5], [5, 0]], + neighborhood=neighborhood, ) train_loader = DataLoader( train_data, batch_size=5, shuffle=True, num_workers=NUM_THREADS @@ -1039,7 +1040,7 @@ def create_aff_target(self, mask): unet = UNet( depth=4, in_channels=2, - out_channels=4, + out_channels=len(neighborhood), final_activation=torch.nn.Sigmoid(), num_fmaps=16, fmap_inc_factor=3, @@ -1078,6 +1079,10 @@ def create_aff_target(self, mask): # %% [markdown] # Let's next look at a prediction on a random image. +# We will be using mutex watershed (see this paper by [Wolf et al.](https://arxiv.org/abs/1904.12654)) for post processing. I won't dive too much into the details, but it is similar to watershed except that it allows edges to have negative weights and for splits, removing the need for finding seed points. +# However this does mean we now need a bias term since if we give it all positive edges (our affinities are in range (0, 1)) everything will join into a single object. Thus our bias should be in range (-1, 0), such that we have some positive and some negative affinities. +# +# It can also be useful to bias long range affinities more negatively than the short range affinities. The intuition here being that boundaries are often blurry in biology. This means it may not be easy to tell if the neighboring pixel has crossed a boundary, but it is reasonably easy to tell if there is a boundary accross a 5 pixel gap. Similarly, identifying if two pixels belong to the same object is easier, the closer they are to each other. Providing a more negative bias to long range affinities means we bias towards splitting on low long range affinities, and merging on high short range affinities. # %% val_data = AffinityDataset("tissuenet_data/test", v2.RandomCrop(256), return_mask=True) @@ -1105,15 +1110,19 @@ def create_aff_target(self, mask): pred[3] + bias_long, ] ).astype(np.float64), - [[0, 1], [1, 0], [0, 5], [5, 0]], + neighborhood, ) +# Mutex watershed often leads to many tiny fragments due to the fuzziness of our models predictions. +# We can add a simple small object filter to get significantly higher accuracy. +precision, recall, accuracy = evaluate(gt_labels, pred_labels) +print(f"Before filter: Precision: {precision:.3f}, Recall: {recall:.3f}, Accuracy: {accuracy:.3f}") plot_four(image, affs, pred, pred_labels, label="Affinity", cmap=label_cmap) -print(evaluate(gt_labels, pred_labels)) pred_labels = remove_small_objects( pred_labels.astype(np.int64), min_size=64, connectivity=1 ) -print(evaluate(gt_labels, pred_labels)) +precision, recall, accuracy = evaluate(gt_labels, pred_labels) +print(f"After filter: Precision: {precision:.3f}, Recall: {recall:.3f}, Accuracy: {accuracy:.3f}") plot_four(image, affs, pred, pred_labels, label="Affinity", cmap=label_cmap) # %% [markdown] # Let's also evaluate the model performance. @@ -1160,13 +1169,16 @@ def create_aff_target(self, mask): pred_labels = mws.agglom( np.array( [ - pred[0] - bias_short, - pred[1] - bias_short, - pred[2] - bias_long, - pred[3] - bias_long, + pred[0] + bias_short, + pred[1] + bias_short, + pred[2] + bias_long, + pred[3] + bias_long, ] ).astype(np.float64), - [[0, 1], [1, 0], [0, 5], [5, 0]], + neighborhood, + ) + pred_labels = remove_small_objects( + pred_labels.astype(np.int64), min_size=64, connectivity=1 ) precision, recall, accuracy = evaluate(gt_labels, pred_labels)