Skip to content

Commit

Permalink
add more postprocessing discussion
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Aug 21, 2024
1 parent 86269b0 commit 99b61bc
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,11 +1000,12 @@ def create_aff_target(self, mask):
# Training will break if you change the number of affinities. It is a simple fix, you will just need to change the number
# of output channels the unet produces.

neighborhood = [[0, 1], [1, 0], [0, 5], [5, 0]]
train_data = AffinityDataset(
"tissuenet_data/train",
v2.RandomCrop(256),
weights=True,
neighborhood=[[0, 1], [1, 0], [0, 5], [5, 0]],
neighborhood=neighborhood,
)
train_loader = DataLoader(
train_data, batch_size=5, shuffle=True, num_workers=NUM_THREADS
Expand Down Expand Up @@ -1039,7 +1040,7 @@ def create_aff_target(self, mask):
unet = UNet(
depth=4,
in_channels=2,
out_channels=4,
out_channels=len(neighborhood),
final_activation=torch.nn.Sigmoid(),
num_fmaps=16,
fmap_inc_factor=3,
Expand Down Expand Up @@ -1078,6 +1079,10 @@ def create_aff_target(self, mask):

# %% [markdown]
# Let's next look at a prediction on a random image.
# We will be using mutex watershed (see this paper by [Wolf et al.](https://arxiv.org/abs/1904.12654)) for post processing. I won't dive too much into the details, but it is similar to watershed except that it allows edges to have negative weights and for splits, removing the need for finding seed points.
# However this does mean we now need a bias term since if we give it all positive edges (our affinities are in range (0, 1)) everything will join into a single object. Thus our bias should be in range (-1, 0), such that we have some positive and some negative affinities.
#
# It can also be useful to bias long range affinities more negatively than the short range affinities. The intuition here being that boundaries are often blurry in biology. This means it may not be easy to tell if the neighboring pixel has crossed a boundary, but it is reasonably easy to tell if there is a boundary accross a 5 pixel gap. Similarly, identifying if two pixels belong to the same object is easier, the closer they are to each other. Providing a more negative bias to long range affinities means we bias towards splitting on low long range affinities, and merging on high short range affinities.

# %%
val_data = AffinityDataset("tissuenet_data/test", v2.RandomCrop(256), return_mask=True)
Expand Down Expand Up @@ -1105,15 +1110,19 @@ def create_aff_target(self, mask):
pred[3] + bias_long,
]
).astype(np.float64),
[[0, 1], [1, 0], [0, 5], [5, 0]],
neighborhood,
)

# Mutex watershed often leads to many tiny fragments due to the fuzziness of our models predictions.
# We can add a simple small object filter to get significantly higher accuracy.
precision, recall, accuracy = evaluate(gt_labels, pred_labels)
print(f"Before filter: Precision: {precision:.3f}, Recall: {recall:.3f}, Accuracy: {accuracy:.3f}")
plot_four(image, affs, pred, pred_labels, label="Affinity", cmap=label_cmap)
print(evaluate(gt_labels, pred_labels))
pred_labels = remove_small_objects(
pred_labels.astype(np.int64), min_size=64, connectivity=1
)
print(evaluate(gt_labels, pred_labels))
precision, recall, accuracy = evaluate(gt_labels, pred_labels)
print(f"After filter: Precision: {precision:.3f}, Recall: {recall:.3f}, Accuracy: {accuracy:.3f}")
plot_four(image, affs, pred, pred_labels, label="Affinity", cmap=label_cmap)
# %% [markdown]
# Let's also evaluate the model performance.
Expand Down Expand Up @@ -1160,13 +1169,16 @@ def create_aff_target(self, mask):
pred_labels = mws.agglom(
np.array(
[
pred[0] - bias_short,
pred[1] - bias_short,
pred[2] - bias_long,
pred[3] - bias_long,
pred[0] + bias_short,
pred[1] + bias_short,
pred[2] + bias_long,
pred[3] + bias_long,
]
).astype(np.float64),
[[0, 1], [1, 0], [0, 5], [5, 0]],
neighborhood,
)
pred_labels = remove_small_objects(
pred_labels.astype(np.int64), min_size=64, connectivity=1
)

precision, recall, accuracy = evaluate(gt_labels, pred_labels)
Expand Down

0 comments on commit 99b61bc

Please sign in to comment.