# %% [markdown] tags=[]
# # Image translation (Virtual Staining)
# Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco
# ## Overview
#
# In this exercise, we will _virtually stain_ the nuclei and plasma membrane from the quantitative phase image (QPI), i.e., translate QPI images into fluorescence images of nuclei and plasma membranes.
# QPI encodes multiple cellular structures, and virtual staining decomposes these structures. After the model is trained, one only needs to acquire label-free QPI data.
# This strategy solves the same problem as multi-spectral imaging, but is more compatible with live-cell imaging and high-throughput screening.
# Virtual staining is often a step towards multiple downstream analyses: segmentation, tracking, and cell state phenotyping.
#
# In this exercise, you will:
# - Train a model to predict the fluorescence images of nuclei and plasma membranes from QPI images
# - Make it robust to variations in imaging conditions using data augmentations
# - Segment the cells
# - Use regression and segmentation metrics to evaluate the models
# - Visualize the image transform learned by the model
# - Understand the failure modes of the trained model
#
# [![HEK293T](https://raw.githubusercontent.com/mehta-lab/VisCy/main/docs/figures/svideo_1.png)](https://github.com/mehta-lab/VisCy/assets/67518483/d53a81eb-eb37-44f3-b522-8bd7bddc7755)
# (Click on image to play video)
#
# %% [markdown] tags=[]
# ### Goals
# #### Part 1: Train a virtual staining model
#
# - Explore OME-Zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html)
# and the high-content screening (HCS) format.
# - Use our `viscy.data.HCSDataModule` data module and explore the 3-channel (phase, fluorescence nuclei, and cell membrane)
# A549 cell dataset.
# - Implement data augmentations with [MONAI](https://monai.io/) to train a model that is robust to variations in imaging parameters and conditions.
# - Use tensorboard to log the augmentations, the training and validation losses, and sample batches
# - Start the training of the UNeXt2 model to predict nuclei and membrane from phase images.
#
# #### Part 2: Evaluate the model to translate phase into fluorescence.
# - Compare the performance of your trained model with the _VSCyto2D_ pre-trained model.
# - Evaluate the model using pixel-level and instance-level metrics.
#
# #### Part 3: Visualize the image transforms learned by the model and explore the model's regime of validity
# - Visualize the first 3 principal components mapped to a color space in each encoder and decoder block.
# - Explore the model's regime of validity by applying blurring and scaling transforms to the input phase image.
#
# #### For more information:
# Check out [VisCy](https://github.com/mehta-lab/VisCy),
# our deep learning pipeline for training and deploying computer vision models
# for image-based phenotyping including the robust virtual staining of landmark organelles.
#
# VisCy exploits recent advances in data and metadata formats
# ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks,
# [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/).
# ### References
# - [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf)
# - [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502)
# %% [markdown] tags=[]
# <div class="alert alert-success">
# The exercise is organized in 3 parts:
# <ul>
# <li><b>Part 1</b> - Train a virtual staining model using iohub (I/O library), VisCy dataloaders, and tensorboard</li>
# <li><b>Part 2</b> - Evaluate the model to translate phase into fluorescence.</li>
# <li><b>Part 3</b> - Visualize the image transforms learned by the model and explore the model's regime of validity.</li>
# </ul>
# </div>
# %% [markdown] tags=[]
# <div class="alert alert-danger">
# Set your python kernel to <span style="color:black;">06_image_translation</span>
# </div>
# %% [markdown]
# # Part 1: Log training data to tensorboard, start training a model.
# ---------
# Learning goals:
# - Load the OME-zarr dataset and examine the channels (A549).
# - Configure and understand the data loader.
# - Log some patches to tensorboard.
# - Initialize a 2D UNeXt2 model for virtual staining of nuclei and membrane from phase.
# - Start training the model to predict nuclei and membrane from phase.
# %% Imports
import os
from glob import glob
from pathlib import Path
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchview
import torchvision
from cellpose import models
from iohub import open_ome_zarr
from iohub.reader import print_info
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import TensorBoardLogger
from natsort import natsorted
from numpy.typing import ArrayLike
from skimage import metrics  # for regression metrics.
from skimage.color import label2rgb
from torch.utils.tensorboard import SummaryWriter  # for logging to tensorboard
from torchmetrics.functional import accuracy, dice, jaccard_index
from tqdm import tqdm
# HCSDataModule makes it easy to load data during training.
from viscy.data.hcs import HCSDataModule
from viscy.evaluation.evaluation_metrics import mean_average_precision
# Trainer class and UNet.
from viscy.light.engine import MixedLoss, VSUNet
from viscy.light.trainer import VSTrainer
# training augmentations
from viscy.transforms import (NormalizeSampled, RandAdjustContrastd,
RandAffined, RandGaussianNoised,
RandGaussianSmoothd, RandScaleIntensityd,
RandWeightedCropd)
# %%
# seed random number generators for reproducibility.
seed_everything(42, workers=True)
# Paths to data and log directory
top_dir = Path(
"/mnt/efs/dlmbl/share/"
)  # If this fails, make sure this points to your data directory in the shared mounting point inside /dlmbl/data
# Path to the training data
data_path = (
top_dir / "06_image_translation/training/a549_hoechst_cellmask_train_val.zarr"
)
# Path where we will save our training logs
training_top_dir = Path(f"{os.getcwd()}/data/")
# Create the training directory if needed
training_top_dir.mkdir(parents=True, exist_ok=True)
log_dir = training_top_dir / "06_image_translation/logs/"
# Create the log directory if needed; tensorboard is launched in the next cell
log_dir.mkdir(parents=True, exist_ok=True)
if not data_path.exists():
raise FileNotFoundError(
f"Data not found at {data_path}. Please check the top_dir and data_path variables."
)
# %% [markdown] tags=[]
# The next cell starts tensorboard.
# <div class="alert alert-warning">
# If you launched jupyter lab from an ssh terminal, add <code>--host &lt;your-server-name&gt;</code> to the tensorboard command below. <code>&lt;your-server-name&gt;</code> is the address of your compute node that ends in amazonaws.com.
# </div>
# %% tags=[]
# Helper functions to find a free port and launch TensorBoard
# Function to find an available port
def find_free_port():
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
# Launch TensorBoard on the browser
def launch_tensorboard(log_dir):
import subprocess
port = find_free_port()
tensorboard_cmd = f"tensorboard --logdir={log_dir} --port={port}"
process = subprocess.Popen(tensorboard_cmd, shell=True)
print(
f"TensorBoard started at http://localhost:{port}. \n"
"If you are using VSCode remote session, forward the port using the PORTS tab next to TERMINAL."
)
return process
# Launch tensorboard and click on the link to view the logs.
tensorboard_process = launch_tensorboard(log_dir)
# %% [markdown] tags=[]
# <div class="alert alert-warning">
# If you are using VSCode and a remote server, you will need to forward the port to view the tensorboard. <br>
# Take note of the port number assigned in the previous cell (i.e. <code>http://localhost:{port_number_assigned}</code>). <br>
# Locate your VSCode terminal and select the <code>Ports</code> tab. <br>
# <ul>
# <li>Add a new port with the <code>port_number_assigned</code></li>
# </ul>
# Click on the link and the tensorboard should open in your browser.
# </div>
# %% [markdown] tags=[]
# ## Load OME-Zarr Dataset
# There should be 34 FOVs in the dataset.
#
# Each FOV consists of 3 channels of 2048x2048 images,
# saved in the [High-Content Screening (HCS) layout](https://ngff.openmicroscopy.org/latest/#hcs-layout)
# specified by the Open Microscopy Environment Next Generation File Format
# (OME-NGFF).
#
# The 3 channels correspond to the QPI, nuclei, and cell membrane. The nuclei were stained with Hoechst and the cell membrane with CellMask.
#
# - The layout on the disk is: `row/col/field/pyramid_level/timepoint/channel/z/y/x.`
# - These datasets only have 1 level in the pyramid (highest resolution) which is '0'.
# %% [markdown] tags=[]
# <div class="alert alert-warning">
# You can inspect the tree structure by using your terminal:
# <code> iohub info -v "path-to-ome-zarr" </code>
# <br>
# More info on the CLI:
# <code>iohub info --help </code> to see the help menu.
# </div>
# %%
# This is the python function called by `iohub info` CLI command
print_info(data_path, verbose=True)
# Open and inspect the dataset.
dataset = open_ome_zarr(data_path)
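# %% [markdown] tags=[]
# As a sketch of the HCS layout described above, we can also traverse the plate
# programmatically (assuming iohub's `Plate.positions()` iterator):
# %%
for fov_name, position in dataset.positions():
    # `fov_name` is the "row/col/field" path; "0" is the highest-resolution level.
    print(fov_name, position["0"].shape)
    break  # remove this to list all 34 FOVs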
# %% [markdown] tags=[]
# <div class="alert alert-info">
#
# ### Task 1.1
# Look at a couple different fields of view (FOVs) by changing the `field` variable.
# Check the cell density, the cell morphologies, and fluorescence signal.
# HINT: look at the HCS Plate format to see what your options are.
# </div>
# %% tags=[]
# Use the field and pyramid_level below to visualize data.
row = 0
col = 0
field = 9  # TODO: Change this to explore data.
# NOTE: this dataset only has one pyramid level
pyramid_level = 0
# `channel_names` is the metadata that is stored with data according to the OME-NGFF spec.
n_channels = len(dataset.channel_names)
image = dataset[f"{row}/{col}/{field}/{pyramid_level}"].numpy()
print(f"data shape: {image.shape}, FOV: {field}, pyramid level: {pyramid_level}")
figure, axes = plt.subplots(1, n_channels, figsize=(9, 3))
for i in range(n_channels):
    channel_image = image[0, i, 0]
    # Adjust contrast to 0.5th and 99.5th percentile of pixel values.
    p_low, p_high = np.percentile(channel_image, (0.5, 99.5))
    channel_image = np.clip(channel_image, p_low, p_high)
    axes[i].imshow(channel_image, cmap="gray")
    axes[i].axis("off")
    axes[i].set_title(dataset.channel_names[i])
plt.tight_layout()
# %% [markdown] tags=[]
# ## Explore the effects of augmentation on a batch.
# VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule`, to make it intuitive to access data stored in the HCS layout.
# The dataloader in `HCSDataModule` returns a batch of samples. A `batch` is a list of dictionaries. The length of the list is equal to the batch size. Each dictionary consists of the following key-value pairs:
# - `source`: the input image, a tensor of size `(1, 1, Y, X)`
# - `target`: the target image, a tensor of size `(2, 1, Y, X)`
# - `index` : the tuple of (location of field in HCS layout, time, and z-slice) of the sample.
# %% [markdown] tags=[]
# <div class="alert alert-info">
#
# ### Task 1.2
# - Run the next cell to set up a logger for your augmentations.
# - Set up the `HCSDataModule` for training.
# - Configure the data module for the `"UNeXt2_2D"` architecture.
# - Configure the data module for the phase (source) to fluorescence cell nuclei and membrane (targets) regression task.
# - Configure the data module for training. Hint: use `HCSDataModule.setup()`.
# - Open your tensorboard and look at the `IMAGES` tab.
#
# Note: If tensorboard is not showing images or the plots, try refreshing and using the "Images" tab.
# </div>
# %%
# Define a function to write a batch to tensorboard log.
def log_batch_tensorboard(batch, batchno, writer, card_name):
"""
Logs a batch of images to TensorBoard.
    Args:
        batch (dict): A dictionary containing the batch of images to be logged.
        batchno (int): The batch number, used as the global step for logging.
        writer (SummaryWriter): A TensorBoard SummaryWriter object.
        card_name (str): The name of the card to be displayed in TensorBoard.
Returns:
None
"""
batch_phase = batch["source"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.
batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(
1
) # batch_size x 1 x Y x X tensor.
batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(
1
) # batch_size x 1 x Y x X tensor.
p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))
batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)
p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))
batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)
p1, p99 = np.percentile(batch_phase, (0.1, 99.9))
batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)
[N, C, H, W] = batch_phase.shape
interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype)
interleaved_images[0::3, :] = batch_phase
interleaved_images[1::3, :] = batch_nuclei
interleaved_images[2::3, :] = batch_membrane
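    # nrow=3 puts each sample's [phase, nuclei, membrane] triplet on its own grid row.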
grid = torchvision.utils.make_grid(interleaved_images, nrow=3)
# add the grid to tensorboard
writer.add_image(card_name, grid, batchno)
# Define a function to visualize a batch in Jupyter, in case tensorboard is finicky
def log_batch_jupyter(batch):
    """
    Displays a batch of images in the notebook using matplotlib.
Args:
batch (dict): A dictionary containing the batch of images to be logged.
Returns:
None
"""
batch_phase = batch["source"][:, :, 0, :, :] # batch_size x z_size x Y x X tensor.
batch_size = batch_phase.shape[0]
batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(
1
) # batch_size x 1 x Y x X tensor.
batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(
1
) # batch_size x 1 x Y x X tensor.
p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))
batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)
p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))
batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)
p1, p99 = np.percentile(batch_phase, (0.1, 99.9))
batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)
n_channels = batch["target"].shape[1] + batch["source"].shape[1]
plt.figure()
fig, axes = plt.subplots(
batch_size, n_channels, figsize=(n_channels * 2, batch_size * 2)
)
[N, C, H, W] = batch_phase.shape
for sample_id in range(batch_size):
axes[sample_id, 0].imshow(batch_phase[sample_id, 0])
axes[sample_id, 1].imshow(batch_nuclei[sample_id, 0])
axes[sample_id, 2].imshow(batch_membrane[sample_id, 0])
for i in range(n_channels):
axes[sample_id, i].axis("off")
axes[sample_id, i].set_title(dataset.channel_names[i])
plt.tight_layout()
plt.show()
# %% tags=["task"]
# Initialize the data module.
BATCH_SIZE = 4
# 4 is a perfectly reasonable batch size
# (batch size does not have to be a power of 2)
# See: https://sebastianraschka.com/blog/2022/batch-size-2.html
# #######################
# ##### TODO ########
# #######################
# HINT: Run dataset.channel_names
source_channel = ["TODO"]
target_channel = ["TODO", "TODO"]
# #######################
# ##### TODO ########
# #######################
data_module = HCSDataModule(
data_path,
z_window_size=1,
architecture= #TODO# 2D UNeXt2 architecture
source_channel=source_channel,
target_channel=target_channel,
split_ratio=0.8,
batch_size=BATCH_SIZE,
num_workers=8,
yx_patch_size=(256, 256), # larger patch size makes it easy to see augmentations.
augmentations=[], # Turn off augmentation for now.
normalizations=[], # Turn off normalization for now.
)
# #######################
# ##### TODO ########
# #######################
# Setup the data_module to fit. HINT: data_module.setup()
# Evaluate the data module
print(
f"Samples in training set: {len(data_module.train_dataset)}, "
f"samples in validation set:{len(data_module.val_dataset)}"
)
train_dataloader = data_module.train_dataloader()
# Instantiate the tensorboard SummaryWriter and log the first batch.
writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
# Draw a batch and write to tensorboard.
batch = next(iter(train_dataloader))
log_batch_tensorboard(batch, 0, writer, "augmentation/none")
writer.close()
# %% tags=["solution"]
# #######################
# ##### SOLUTION ########
# #######################
BATCH_SIZE = 4
# 4 is a perfectly reasonable batch size
# (batch size does not have to be a power of 2)
# See: https://sebastianraschka.com/blog/2022/batch-size-2.html
source_channel = ["Phase3D"]
target_channel = ["Nucl", "Mem"]
data_module = HCSDataModule(
data_path,
z_window_size=1,
architecture="UNeXt2_2D",
source_channel=source_channel,
target_channel=target_channel,
split_ratio=0.8,
batch_size=BATCH_SIZE,
num_workers=8,
yx_patch_size=(256, 256), # larger patch size makes it easy to see augmentations.
augmentations=[], # Turn off augmentation for now.
normalizations=[], # Turn off normalization for now.
)
# Setup the data_module to fit. HINT: data_module.setup()
data_module.setup("fit")
# Evaluate the data module
print(
f"Samples in training set: {len(data_module.train_dataset)}, "
f"samples in validation set:{len(data_module.val_dataset)}"
)
train_dataloader = data_module.train_dataloader()
# Instantiate the tensorboard SummaryWriter and log the first batch.
writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
# Draw a batch and write to tensorboard.
batch = next(iter(train_dataloader))
log_batch_tensorboard(batch, 0, writer, "augmentation/none")
writer.close()
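# %% [markdown] tags=[]
# Quick sanity check of the batch structure described above. This is a sketch;
# the exact keys and shapes depend on the `HCSDataModule` configuration.
# %%
sample = batch[0] if isinstance(batch, list) else batch
print("source:", tuple(sample["source"].shape))  # (batch, 1, Z, Y, X) phase
print("target:", tuple(sample["target"].shape))  # (batch, 2, Z, Y, X) nuclei + membrane
print("index:", sample["index"])  # (FOV location, timepoint, z-slice) of each sample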
# %% [markdown] tags=[]
# <div class="alert alert-warning">
#
# ### Questions
# 1. What are the two channels in the target image?
# 2. How many samples are in the training and validation set? What determined that split?
#
# Note: If tensorboard is not showing images, try refreshing and using the "Images" tab.
# </div>
# %% [markdown] tags=[]
# If your tensorboard is causing issues, you can visualize directly in Jupyter/VSCode.
# %%
# Visualize in Jupyter
log_batch_jupyter(batch)
# %% [markdown] tags=[]
# <div class="alert alert-warning">
# <h3> Question for Task 1.3 </h3>
# How can we make the model more robust to imaging parameters or conditions
# without having to acquire data for every possible condition? <br>
# </div>
# %% [markdown] tags=[]
# <div class="alert alert-info">
#
# ### Task 1.3
# Add the following augmentations:
# - Add an affine augmentation with rotation of up to $\pi$ radians about the z-axis, up to 30% scaling in y and x,
# a small shear, and zero padding, applied with a probability of 80%.
# - Add Gaussian noise with a mean of 0.0 and a standard deviation of 0.3 with a probability of 50%.
#
# HINT: `RandAffined()` and `RandGaussianNoised()` are from
# `viscy.transforms` [here](https://github.com/mehta-lab/VisCy/blob/main/viscy/transforms.py). You can look at the docs by running `RandAffined?`.<br><br>
# *Note these are MONAI transforms that have been redefined for VisCy.*
# [Compare your choice of augmentations by downloading the pretrained models and config files](https://github.com/mehta-lab/VisCy/releases/download/v0.1.0/VisCy-0.1.0-VS-models.zip).
# </div>
# %% tags=["task"]
# Here we turn on data augmentation and rerun setup
# #######################
# ##### TODO ########
# #######################
# HINT: Run dataset.channel_names
source_channel = ["TODO"]
target_channel = ["TODO", "TODO"]
augmentations = [
RandWeightedCropd(
keys=source_channel + target_channel,
spatial_size=(1, 384, 384),
num_samples=2,
w_key=target_channel[0],
),
# #######################
# ##### TODO ########
# #######################
## TODO: Add Random Affine Transforms
## Write code below
# #######################
RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)),
RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5),
# #######################
# ##### TODO ########
# #######################
## TODO: Add Random Gaussian Noise
## Write code below
# #######################
RandGaussianSmoothd(
keys=source_channel,
sigma_x=(0.25, 0.75),
sigma_y=(0.25, 0.75),
sigma_z=(0.0, 0.0),
prob=0.5,
),
]
normalizations = [
NormalizeSampled(
keys=source_channel,
level="fov_statistics",
subtrahend="mean",
divisor="std",
),
NormalizeSampled(
keys=target_channel,
level="fov_statistics",
subtrahend="median",
divisor="iqr",
),
]
data_module.augmentations = augmentations
data_module.setup("fit")
# get the new data loader with augmentation turned on
augmented_train_dataloader = data_module.train_dataloader()
# Draw batches and write to tensorboard
writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
augmented_batch = next(iter(augmented_train_dataloader))
log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some")
writer.close()
# %% tags=["solution"]
# #######################
# ##### SOLUTION ########
# #######################
source_channel = ["Phase3D"]
target_channel = ["Nucl", "Mem"]
augmentations = [
RandWeightedCropd(
keys=source_channel + target_channel,
spatial_size=(1, 384, 384),
num_samples=2,
w_key=target_channel[0],
),
RandAffined(
keys=source_channel + target_channel,
rotate_range=[3.14, 0.0, 0.0],
scale_range=[0.0, 0.3, 0.3],
prob=0.8,
padding_mode="zeros",
shear_range=[0.0, 0.01, 0.01],
),
RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)),
RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5),
RandGaussianNoised(keys=source_channel, prob=0.5, mean=0.0, std=0.3),
RandGaussianSmoothd(
keys=source_channel,
sigma_x=(0.25, 0.75),
sigma_y=(0.25, 0.75),
sigma_z=(0.0, 0.0),
prob=0.5,
),
]
normalizations = [
NormalizeSampled(
keys=source_channel,
level="fov_statistics",
subtrahend="mean",
divisor="std",
),
NormalizeSampled(
keys=target_channel,
level="fov_statistics",
subtrahend="median",
divisor="iqr",
)
]
data_module.augmentations = augmentations
# Setup the data_module to fit. HINT: data_module.setup()
data_module.setup("fit")
# get the new data loader with augmentation turned on
augmented_train_dataloader = data_module.train_dataloader()
# Draw batches and write to tensorboard
writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
augmented_batch = next(iter(augmented_train_dataloader))
log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some")
writer.close()
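# %% [markdown] tags=[]
# To isolate the effect of a single augmentation, you can also apply one transform
# to a sample dictionary directly. This is a sketch, assuming the MONAI-style
# dictionary transforms accept any dict keyed by channel name:
# %%
single_noise = RandGaussianNoised(keys=source_channel, prob=1.0, mean=0.0, std=0.3)
phase_sample = {source_channel[0]: augmented_batch["source"][0]}  # (C, Z, Y, X)
noisy_phase = single_noise(phase_sample)[source_channel[0]]
fig, ax = plt.subplots(1, 2, figsize=(6, 3))
ax[0].imshow(np.asarray(phase_sample[source_channel[0]])[0, 0], cmap="gray")
ax[0].set_title("before")
ax[1].imshow(np.asarray(noisy_phase)[0, 0], cmap="gray")
ax[1].set_title("after noise")
for a in ax:
    a.axis("off")
plt.show()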
# %% [markdown] tags=[]
# <div class="alert alert-warning">
# <h3> Questions for Task 1.3 </h3>
# 1. Look at your tensorboard. Can you tell that the augmentations were applied to the sample batch? Compare the batch with and without augmentations. <br>
# 2. Are these augmentations good enough? What else would you add?
# </div>
# %% [markdown]
# Visualize directly in Jupyter
# %%
log_batch_jupyter(augmented_batch)
# %% [markdown] tags=[]
# ## Train a 2D U-Net model to predict nuclei and membrane from phase.
# ### Constructing a 2D UNeXt2 using VisCy
# %% [markdown]
# <div class="alert alert-info">
#
# ### Task 1.4
# - Run the next cell to instantiate the `UNeXt2_2D` model.
# - Configure the network for the phase (source) to fluorescence cell nuclei and membrane (targets) regression task.
# - Call the `VSUNet` with the `"UNeXt2_2D"` architecture.
# - Run the next cells to instantiate the data module and trainer.
# - Add the source channel name and the target channel names.
# - Run a quick `fast_dev_run` to verify the setup before training. <br>
#
# <b> Note </b> <br>
# See `viscy.unet.networks.Unet2D.Unet2d` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) to learn more about the configuration.
# </div>
# %% tags=["task"]
# Create a 2D UNet.
GPU_ID = 0
BATCH_SIZE = 16
YX_PATCH_SIZE = (256, 256)
# #######################
# ##### TODO ########
# #######################
# Dictionary that specifies key parameters of the model.
phase2fluor_config = dict(
in_channels= #TODO how many input channels are we feeding Hint: int?,
out_channels= #TODO how many output channels are we solving for? Hint: int,
encoder_blocks=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
decoder_conv_blocks=2,
stem_kernel_size=(1, 2, 2),
in_stack_depth= #TODO was this a 2D or 3D input? HINT: int,
pretraining=False,
)
# #######################
# ##### TODO ########
# #######################
phase2fluor_model = VSUNet(
architecture= #TODO# 2D UNeXt2 architecture
model_config=phase2fluor_config.copy(),
loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5),
schedule="WarmupCosine",
lr=6e-4,
log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard.
freeze_encoder=False,
)
# #######################
# ##### TODO ########
# #######################
# HINT: Run dataset.channel_names
source_channel = ["TODO"]
target_channel = ["TODO", "TODO"]
# Setup the data module.
phase2fluor_2D_data = HCSDataModule(
data_path,
architecture=#TODO# 2D UNeXt2 architecture. Same string as above.
source_channel=source_channel,
target_channel=target_channel,
z_window_size=1,
split_ratio=0.8,
batch_size=BATCH_SIZE,
num_workers=8,
yx_patch_size=YX_PATCH_SIZE,
augmentations=augmentations,
normalizations=normalizations,
)
phase2fluor_2D_data.setup("fit")
# fast_dev_run runs a single batch of data through the model to check for errors.
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], precision="16-mixed", fast_dev_run=True)
# trainer class takes the model and the data module as inputs.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)
# %% tags=["solution"]
# Here we are creating a 2D UNet.
GPU_ID = 0
BATCH_SIZE = 16
YX_PATCH_SIZE = (256, 256)
# Dictionary that specifies key parameters of the model.
# #######################
# ##### SOLUTION ########
# #######################
phase2fluor_config = dict(
in_channels=1,
out_channels=2,
encoder_blocks=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
decoder_conv_blocks=2,
stem_kernel_size=(1, 2, 2),
in_stack_depth=1,
pretraining=False,
)
phase2fluor_model = VSUNet(
architecture="UNeXt2_2D", # 2D UNeXt2 architecture
model_config=phase2fluor_config.copy(),
loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5),
schedule="WarmupCosine",
lr=6e-4,
log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard.
freeze_encoder=False,
)
# ### Instantiate data module and trainer, and test that we are set up to launch training.
# #######################
# ##### SOLUTION ########
# #######################
# Selecting the source and target channel names from the dataset.
source_channel = ["Phase3D"]
target_channel = ["Nucl", "Mem"]
# Setup the data module.
phase2fluor_2D_data = HCSDataModule(
data_path,
architecture="UNeXt2_2D",
source_channel=source_channel,
target_channel=target_channel,
z_window_size=1,
split_ratio=0.8,
batch_size=BATCH_SIZE,
num_workers=8,
yx_patch_size=YX_PATCH_SIZE,
augmentations=augmentations,
normalizations=normalizations,
)
# #######################
# ##### SOLUTION ########
# #######################
phase2fluor_2D_data.setup("fit")
# fast_dev_run runs a single batch of data through the model to check for errors.
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], precision="16-mixed", fast_dev_run=True)
# trainer class takes the model and the data module as inputs.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)
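# %% [markdown] tags=[]
# A note on the loss: `MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5)` mixes
# an L1 term with a multi-scale DSSIM term. As a hedged sketch, assuming the standard
# DSSIM definition, the objective is roughly
#
# $\mathcal{L} = \alpha_{L1} \lVert y - \hat{y} \rVert_1 + \alpha_{MS\text{-}DSSIM} \, \frac{1 - \mathrm{MS\text{-}SSIM}(y, \hat{y})}{2}$
#
# with the L2 term disabled here; see the `MixedLoss` source in VisCy for the exact form.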
# %% [markdown] tags=[]
# ## View model graph.
#
# PyTorch uses dynamic graphs under the hood.
# The graphs are constructed on the fly.
# This is in contrast to TensorFlow,
# where the graph is constructed before the training loop and remains static.
# In other words, the graph of the network can change with every forward pass.
# Therefore, we need to supply an input tensor to construct the graph.
# The input tensor can be a random tensor of the correct shape and type.
# We can also supply a real image from the dataset.
# The latter is more useful for debugging.
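# %% [markdown]
# As a minimal sketch of the random-tensor option (assuming the model expects
# `(batch, channel, z, y, x)` input, here `(1, 1, 1, 256, 256)`):
# %%
# Draw the graph from a random input tensor instead of a real sample.
torchview.draw_graph(
    phase2fluor_model,
    torch.rand(1, 1, 1, 256, 256),
    depth=2,
    device="cpu",
).visual_graph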
# %% [markdown]
# <div class="alert alert-info">
#
# ### Task 1.5
# Run the next cell to generate a graph representation of the model architecture.
# </div>
# %%
# visualize graph of phase2fluor model as image.
model_graph_phase2fluor = torchview.draw_graph(
phase2fluor_model,
phase2fluor_2D_data.train_dataset[0]["source"][0].unsqueeze(dim=0),
roll=True,
depth=3, # adjust depth to zoom in.
device="cpu",
# expand_nested=True,
)
# Print the image of the model.
model_graph_phase2fluor.visual_graph
# %% [markdown] tags=[]
# <div class="alert alert-warning">
#
# ### Question:
# Can you recognize the UNet structure and skip connections in this graph visualization?
# </div>
# %% [markdown]
# <div class="alert alert-info">
# <h3> Task 1.6 </h3>
# Start training by running the following cell. Check the new logs on the tensorboard.
# </div>
# %%
# Check if GPU is available
# You can check by typing `nvidia-smi`
GPU_ID = 0
n_samples = len(phase2fluor_2D_data.train_dataset)
steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch.
n_epochs = 80 # Set this to 80-100 or the number of epochs you want to train for.
trainer = VSTrainer(
accelerator="gpu",
devices=[GPU_ID],
max_epochs=n_epochs,
precision='16-mixed',
log_every_n_steps=steps_per_epoch // 2,
# log losses and image samples 2 times per epoch.
logger=TensorBoardLogger(
save_dir=log_dir,
# lightning trainer transparently saves logs and model checkpoints in this directory.
name="phase2fluor",
log_graph=True,
),
)
# Launch training and check that loss and images are being logged on tensorboard.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)
# Move the model to the GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
phase2fluor_model.to(device)
# %% [markdown] tags=[]
# <div class="alert alert-success">
# <h2> Checkpoint 1 </h2>
# While your model is training, let's think about the following questions:<br>
# <ul>
# <li>What is the information content of each channel in the dataset?</li>
# <li>How would you use image translation models?</li>
# <li>What can you try to improve the performance of each model?</li>
# </ul>
# Now the training has started,
# we can come back after a while and evaluate the performance!
# </div>
# %% [markdown] tags=[]
# # Part 2: Assess your trained model
# Now we will look at some performance metrics for the model we just trained.
# We typically evaluate model performance on held-out test data.
# We will use the following metrics to evaluate the accuracy of the regression:
# - [Pearson correlation](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient).
# - [Structural similarity](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM).
# You should also look at the validation samples on tensorboard
# (hint: the experimental data in the nuclei channel is imperfect).
# %% [markdown]
# <div class="alert alert-info">
# <h3> Task 2.1 Define metrics </h3>
# For each of the above metrics, write a brief definition of what they are and what they mean
# for this image translation task. Use your favorite search engine and/or resources.
# </div>
# %% [markdown] tags=[]
# ```
# #######################
# ##### Todo ############
# #######################
#
# ```
#
# - Pearson Correlation:
#
# - Structural similarity:
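# %% [markdown] tags=[]
# As a minimal sketch of how these metrics can be computed with NumPy and
# scikit-image (on synthetic arrays here; the real evaluation loop follows below):
# %%
rng = np.random.default_rng(0)
toy_target = rng.normal(size=(256, 256)).astype(np.float32)
toy_prediction = toy_target + 0.1 * rng.normal(size=(256, 256)).astype(np.float32)
# Pearson correlation: linear correlation between predicted and target pixels.
toy_pearson = np.corrcoef(toy_prediction.ravel(), toy_target.ravel())[0, 1]
# SSIM: compares local luminance, contrast, and structure between the two images.
toy_ssim = metrics.structural_similarity(
    toy_prediction, toy_target, data_range=toy_target.max() - toy_target.min()
)
print(f"Pearson: {toy_pearson:.3f}, SSIM: {toy_ssim:.3f}")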
# %% [markdown] tags=[]
# ### Let's compute metrics directly and plot below.
# %% [markdown] tags=[]
# <div class="alert alert-danger">
# If you weren't able to train or training didn't complete please run the following lines to load the latest checkpoint <br>
#
# ```python
# phase2fluor_model_ckpt = natsorted(glob(
# str(top_dir / "06_image_translation/logs/phase2fluor/version*/checkpoints/*.ckpt")
# ))[-1]
#```
#<br>
# NOTE: if their model didn't go past epoch 5, lost their checkpoint, or didnt train anything.
# Run the following:
#
#```python
#phase2fluor_model_ckpt = natsorted(glob(
# str(top_dir/"06_image_translation/backup/phase2fluor/version_0/checkpoints/*.ckpt")
#))[-1]
#```
#```python
#phase2fluor_config = dict(
# in_channels=1,
# out_channels=2,
# encoder_blocks=[3, 3, 9, 3],
# dims=[96, 192, 384, 768],
# decoder_conv_blocks=2,
# stem_kernel_size=(1, 2, 2),
# in_stack_depth=1,
# pretraining=False,
# )
# Load the model checkpoint
# phase2fluor_model = VSUNet.load_from_checkpoint(
# phase2fluor_model_ckpt,
# architecture="UNeXt2_2D",
# model_config = phase2fluor_config,
# accelerator='gpu'
# )
#````
# </div>
# %%
# Setup the test data module.
test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr"
source_channel = ["Phase3D"]
target_channel = ["Nucl", "Mem"]
test_data = HCSDataModule(
test_data_path,
source_channel=source_channel,
target_channel=target_channel,
z_window_size=1,
batch_size=1,
num_workers=8,
architecture="UNeXt2",
)
test_data.setup("test")
test_metrics = pd.DataFrame(
columns=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"]
)
# %%
# Compute metrics directly and plot here.
def normalize_fov(input: ArrayLike):
    """Normalize the FOV to zero mean and unit variance."""
    mean = np.mean(input)
    std = np.std(input)
    return (input - mean) / std
for i, sample in enumerate(tqdm(test_data.test_dataloader(), desc="Computing metrics per sample")):
phase_image = sample["source"].to(phase2fluor_model.device)
with torch.inference_mode(): # turn off gradient computation.
predicted_image = phase2fluor_model(phase_image)
target_image = (