Skip to content

Commit

Permalink
Merge pull request #13 from mantidproject/37527_deploy_cnn
Browse files Browse the repository at this point in the history
37527 scripts to run CNN Bragg peaks finding for WISH
  • Loading branch information
RichardWaiteSTFC authored Aug 7, 2024
2 parents 1d49c2f + 1be79bc commit f90bafb
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 64 deletions.
Empty file.
67 changes: 67 additions & 0 deletions diffraction/WISH/bragg-detect/bragg_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from mantid.simpleapi import CreatePeaksWorkspace, AddPeak
import numpy as np

def make_3d_array(ws, ntubes=1520, npix_per_tube=128, ntubes_per_bank=152, nmonitors=5):
    """
    Build a 3d counts array (tube, pixel, bin) from the Y data of a WISH MatrixWorkspace.

    The leading monitor spectra are dropped, the pixel order along every tube is flipped
    and the tube order inside every complete bank is reversed.

    :param ws: MatrixWorkspace of WISH data (only extractY() and blocksize() are called)
    :param ntubes: number of tubes in instrument (WISH)
    :param npix_per_tube: number of detector pixels in each tube
    :param ntubes_per_bank: number of tubes per bank
    :param nmonitors: number of monitor spectra (assumed to be the first spectra in ws)
    :return: 3d numpy array of shape ntubes x npix_per_tube x nbins
    """
    detector_counts = ws.extractY()[nmonitors:, :]  # monitors are the first spectra
    nbins = ws.blocksize()
    # tubes x pixels x bins, with the pixel axis flipped along each tube
    cube = np.reshape(detector_counts, (ntubes, npix_per_tube, nbins))[:, ::-1, :]
    # reverse the tube order bank by bank, in place so no second full-size array is allocated
    n_banked_tubes = (ntubes // ntubes_per_bank) * ntubes_per_bank
    for first_tube in range(0, n_banked_tubes, ntubes_per_bank):
        bank = slice(first_tube, first_tube + ntubes_per_bank)
        cube[bank, :, :] = cube[bank, :, :][::-1, :, :]
    return cube


def createPeaksWorkspaceFromIndices(ws, peak_wsname, indices, data):
    """
    Create a PeaksWorkspace from the indices of peaks found in the 3d array.

    Each WISH spectrum groups several detectors (four), so the peak is attached to a central one.
    Could add peak using avg QLab for peaks at that lambda in all detectors but peak placed in
    nearest detector anyway - for more details see https://github.com/mantidproject/mantid/issues/31944

    :param ws: MatrixWorkspace of WISH data the 3d array was made from
    :param peak_wsname: Output name of peaks workspace created
    :param indices: 2d integer array, one row per peak, columns (tube, pixel, bin) in the 3d array
    :param data: 3d array of data (its first two dimensions give ntubes and npix_per_tube)
    :return: PeaksWorkspace
    """
    spectrum_indices = findSpectrumIndex(indices, *data.shape[0:2])
    peaks = CreatePeaksWorkspace(InstrumentWorkspace=ws, NumberOfPeaks=0,
                                 OutputWorkspace=peak_wsname, EnableLogging=False)
    for ipk, wsindex in enumerate(spectrum_indices):
        itube, ipix, ibin = indices[ipk, 0], indices[ipk, 1], indices[ipk, 2]
        # X value of the spectrum at the peak's bin, handed to AddPeak through its TOF argument
        x_at_peak = ws.readX(wsindex)[ibin]
        # several detectors per spectrum, so use one of the central ones
        det_ids = ws.getSpectrum(wsindex).getDetectorIDs()
        central = (len(det_ids) - 1) // 2
        AddPeak(PeaksWorkspace=peaks, RunWorkspace=ws, TOF=x_at_peak, DetectorID=det_ids[central],
                BinCount=data[itube, ipix, ibin], EnableLogging=False)
    return peaks


def findSpectrumIndex(indices, ntubes=1520, npix_per_tube=128, ntubes_per_bank=152, nmonitors=5):
    """
    Map (tube, pixel) positions in the 3d array to workspace spectrum indices.

    :param indices: 2d integer array of found peaks in 3d array (column 0 = tube, column 1 = pixel)
    :param ntubes: number of tubes in instrument (WISH)
    :param npix_per_tube: number of detector pixels in each tube
    :param ntubes_per_bank: number of tubes per bank
    :param nmonitors: number of monitor spectra (assumed first spectra in ws)
    :return: list of spectrum indices of workspace corresponding to indices of found peaks in 3d array
    """
    tube_in_array = indices[:, 0]
    pix_in_array = indices[:, 1]
    # locate the bank of each tube, then reverse the order of tubes within that bank
    first_tube_of_bank = np.floor(tube_in_array / ntubes_per_bank) * ntubes_per_bank
    mirrored_offset = (ntubes_per_bank - 1) - tube_in_array % ntubes_per_bank
    itube = (first_tube_of_bank + mirrored_offset).astype(int)
    # flip the pixel position along the tube
    ipix = (npix_per_tube - 1) - pix_in_array
    # flatten (tube, pixel) row-major and skip over the monitor spectra
    flat_index = np.ravel_multi_index((itube, ipix), dims=(ntubes, npix_per_tube), order='C')
    return (flat_index + nmonitors).tolist()  # tolist so elements are python int, not numpy ints
111 changes: 111 additions & 0 deletions diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torch as tc
import warnings
from cnn.WISHDataSets import WISHWorkspaceDataSet
import numpy as np
from bragg_utils import createPeaksWorkspaceFromIndices
from tqdm import tqdm
from Diffraction.single_crystal.base_sx import BaseSX
import time

class BraggDetectCNN:
    """
    Detects Bragg peaks from a workspace using a pre trained deep learning model created using
    Faster R-CNN model with a ResNet-50-FPN backbone.

    Example usage is as shown below.
        # 1) Set the path of the .pt file that contains the weights of the pre trained FasterRCNN model
        cnn_weights_path = r""
        # 2) Create a peaks workspace containing bragg peaks detected with a confidence greater than conf_threshold
        cnn_bragg_peaks_detector = BraggDetectCNN(model_weights_path=cnn_weights_path, batch_size=64, workers=0, iou_threshold=0.001)
        cnn_bragg_peaks_detector.find_bragg_peaks(workspace="WISH00042730", conf_threshold=0.0, q_tol=0.05)
    """

    def __init__(self, model_weights_path, batch_size=64, workers=0, iou_threshold=0.001):
        """
        :param model_weights_path: Path to the .pt file containing the weights of the pre trained CNN model
        :param batch_size: Batch size to be used when loading tube data for inferencing.
        :param workers: Number of loader worker processes to do multi-process data loading.
                        workers=0 means data loading in main process
        :param iou_threshold: IOU(Intersection Over Union) threshold to filter out overlapping
                              bounding boxes for detected peaks
        """
        self.model_weights_path = model_weights_path
        self.device = self._select_device()
        self.model = self._load_cnn_model_from_weights(self.model_weights_path)
        self.batch_size = batch_size
        self.workers = workers
        self.iou_threshold = iou_threshold

    def find_bragg_peaks(self, workspace, output_ws_name="CNN_Peaks", conf_threshold=0.0, q_tol=0.05):
        """
        Find bragg peaks using the pre trained FasterRCNN model and create a peaks workspace.

        The peaks workspace is created under the name output_ws_name; nothing is returned.
        The confidence score of each detection is stored as the intensity of its peak.

        :param workspace: Workspace name or the object of Workspace from WISH, ex: "WISH0042730"
        :param output_ws_name: Name of the peaks workspace
        :param conf_threshold: Confidence threshold to filter peaks inferred from RCNN
        :param q_tol: qlab tolerance to remove duplicate peaks
        """
        start_time = time.time()
        data_set, predicted_indices = self._do_cnn_inferencing(workspace)
        try:
            # last column is the score, the first three are (tube, pixel, bin)
            filtered_indices = predicted_indices[predicted_indices[:, -1] > conf_threshold]
            data_3d = data_set.get_ws_as_3d_array()
            filtered_indices_rounded = self._round_to_valid_indices(filtered_indices[:, :-1], data_3d.shape)
            peaksws = createPeaksWorkspaceFromIndices(data_set.get_workspace(), output_ws_name,
                                                      filtered_indices_rounded, data_3d)
            for ipk, pk in enumerate(peaksws):
                pk.setIntensity(filtered_indices[ipk, -1])

            # Filter duplicates by qlab
            BaseSX.remove_duplicate_peaks_by_qlab(peaksws, q_tol)
        finally:
            # the rebunched workspace is only a temporary input; remove it even if a step above failed
            data_set.delete_rebunched_ws()
        print(f"Bragg peaks finding from FasterRCNN model is completed in {time.time()-start_time} seconds!")

    @staticmethod
    def _as_prediction_array(rows):
        """
        Convert a list of [tube, pixel, bin, score] rows to a float array of shape (n, 4).

        The explicit reshape keeps the array 2d when there are no rows, so column indexing
        by the caller still works when nothing was detected.
        """
        return np.asarray(rows, dtype=float).reshape(-1, 4)

    @staticmethod
    def _round_to_valid_indices(float_indices, array_shape):
        """
        Round (tube, pixel, bin) positions to integers and clip them into the bounds of the 3d array.

        A box centre lying on the upper edge of the image rounds to the axis length, which is
        one past the last valid index; clipping moves it back onto the last pixel/bin.

        :param float_indices: (n, 3) array of positions in the 3d data array
        :param array_shape: shape of the 3d data array (ntubes, npix_per_tube, nbins)
        :return: (n, 3) integer array of valid indices
        """
        rounded = np.round(float_indices).astype(int)
        upper = np.asarray(array_shape[:3]) - 1
        return np.clip(rounded, 0, upper)

    def _do_cnn_inferencing(self, workspace):
        """
        Run the model on every tube image of the workspace.

        :param workspace: Workspace name or Workspace object handed to WISHWorkspaceDataSet
        :return: (data_set, predictions) where predictions is an (n, 4) float array with
                 columns tube index, pixel position, bin position, score
        """
        data_set = WISHWorkspaceDataSet(workspace)
        try:
            data_loader = tc.utils.data.DataLoader(data_set, batch_size=self.batch_size, shuffle=False,
                                                   num_workers=self.workers)
            self.model.eval()
            predicted_indices_with_score = []
            for batch_idx, img_batch in enumerate(tqdm(data_loader, desc="Processing batches of tubes")):
                for img_idx, img in enumerate(img_batch):
                    # loader is not shuffled, so the position in the stream is the tube index
                    tube_idx = batch_idx * data_loader.batch_size + img_idx
                    with tc.no_grad():
                        prediction = self.model([img.to(self.device)])[0]
                    nms_prediction = self._apply_nms(prediction, self.iou_threshold)
                    boxes = nms_prediction['boxes']
                    # box layout is (x1, y1, x2, y2): x runs along the bins, y along the tube pixels
                    tof = ((boxes[:, 0] + boxes[:, 2]) / 2).tolist()
                    tube_res = ((boxes[:, 1] + boxes[:, 3]) / 2).tolist()
                    scores = nms_prediction['scores'].tolist()
                    predicted_indices_with_score.extend(
                        [tube_idx, pix, bin_pos, score] for pix, bin_pos, score in zip(tube_res, tof, scores))
        except BaseException:
            # do not leave the temporary rebunched workspace behind when inference fails or is interrupted
            data_set.delete_rebunched_ws()
            raise
        return data_set, self._as_prediction_array(predicted_indices_with_score)

    def _apply_nms(self, orig_prediction, iou_thresh):
        """
        Apply non-maximum suppression to one prediction and return the surviving detections.

        :param orig_prediction: dict with 'boxes', 'scores' and 'labels' tensors (left unmodified)
        :param iou_thresh: IOU threshold passed to torchvision.ops.nms
        :return: new dict holding only the kept boxes, scores and labels
        """
        keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)
        final_prediction = dict(orig_prediction)  # shallow copy so the caller's dict is not altered
        for key in ('boxes', 'scores', 'labels'):
            final_prediction[key] = orig_prediction[key][keep]
        return final_prediction

    def _select_device(self):
        """
        Return the CUDA device when available, otherwise warn and fall back to the CPU.
        """
        if tc.cuda.is_available():
            print("GPU device is found!")
            return tc.device("cuda")
        warnings.warn(
            "Warning! GPU is not available, the program will run very slow..", RuntimeWarning)
        return tc.device("cpu")

    def _load_cnn_model_from_weights(self, weights_path):
        """
        Build the two-class Faster R-CNN model and load its weights from weights_path.

        :param weights_path: path of the .pt file holding the model state dict
        :return: model moved to the selected device
        """
        model = self._get_fasterrcnn_resnet50_fpn(num_classes=2)
        # NOTE(review): torch.load unpickles the file - only load weight files from a trusted source
        model.load_state_dict(tc.load(weights_path, map_location=self.device))
        return model.to(self.device)

    def _get_fasterrcnn_resnet50_fpn(self, num_classes=2):
        """
        Create an untrained Faster R-CNN (ResNet-50-FPN) with its box predictor sized for num_classes.
        """
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # replace the head
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        return model

23 changes: 23 additions & 0 deletions diffraction/WISH/bragg-detect/cnn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Bragg Peaks detection using a Faster RCNN model
================

In order to use the pretrained Faster RCNN model inside mantid, the steps below are required.

* Install mantid from conda `mamba create -n mantid_cnn -c mantid mantidworkbench`
* Activate the conda environment with `mamba activate mantid_cnn`
* Launch workbench from `workbench` command
* Download the script repository's `scriptrepository\diffraction\WISH` directory as instructed here https://docs.mantidproject.org/nightly/workbench/scriptrepository.html
* Check whether `<local path>\diffraction\WISH` path is available at `Python Script Directories` tab from `File->Manage User Directories`.
* Close the workbench
* From command line, change the directory to the place where the scripts were downloaded ex: `<local path>\diffraction\WISH\bragg-detect\cnn`
* Within the same conda environment, install the pytorch dependencies by running `pip install -r requirements.txt`
* Install NVIDIA CUDA Deep Neural Network library (cuDNN) by running `conda install -c anaconda cudnn`
* Re-launch workbench from `workbench` command
* Below is an example code snippet to test the code. It creates a peaks workspace containing the peaks inferred by the CNN and then removes duplicate peaks with `BaseSX.remove_duplicate_peaks_by_qlab`, using the `q_tol` provided.
```python
from cnn.BraggDetectCNN import BraggDetectCNN
model_weights = r'path/to/pretrained/fasterrcnn_resnet50_model_weights.pt'
cnn_peaks_detector = BraggDetectCNN(model_weights_path=model_weights, batch_size=64)
cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, q_tol=0.05)
```
* If the above import is not working, check whether the `<local path>\diffraction\WISH` path is listed under `Python Script Directories` tab from `File->Manage User Directories`.
49 changes: 49 additions & 0 deletions diffraction/WISH/bragg-detect/cnn/WISHDataSets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch as tc
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from bragg_utils import make_3d_array
from mantid.simpleapi import Load, Rebunch, Workspace, DeleteWorkspace


class WISHWorkspaceDataSet(tc.utils.data.Dataset):
    """
    Torch Dataset that serves a WISH workspace one tube at a time.

    On construction the workspace is rebunched (NBunch=3) into a hidden workspace and its
    counts are reshaped into a 3d array with make_3d_array; item idx is tube idx of that array.
    """

    def __init__(self, workspace):
        """
        :param workspace: Workspace object whose X-axis unit is TOF, or a string that is passed
                          to Load as both Filename and OutputWorkspace
        :raises RuntimeError: if a Workspace object is not in TOF, or workspace is neither a
                              Workspace nor a string
        """
        if isinstance(workspace, Workspace):
            if workspace.getAxis(0).getUnit().unitID() != "TOF":
                raise RuntimeError("Unit of the X-axis is expected to be TOF")
            ws = workspace
            ws_name = ws.getName()
        elif isinstance(workspace, str):
            ws = Load(Filename=workspace, OutputWorkspace=workspace, EnableLogging=False)
            # the loaded workspace was stored under this very name
            ws_name = workspace
        else:
            raise RuntimeError("Invalid workspace type - must be Workspace object or a name of a workspace to Load")

        self.rebunched_ws = Rebunch(InputWorkspace=ws, NBunch=3, OutputWorkspace=f"__{ws_name}_cnn_rebunched",
                                    EnableLogging=False)
        self.ws_3d = make_3d_array(self.rebunched_ws)
        print(f"Data set for {workspace} is created with shape{self.ws_3d.shape}")
        self.trans = A.Compose([ToTensorV2(p=1.0)])

    def get_workspace(self):
        """
        Return the rebunched workspace.

        :raises RuntimeError: if the rebunched workspace has already been deleted
        """
        if self.rebunched_ws is None:
            raise RuntimeError("Rebunched workspace is not available!")
        return self.rebunched_ws

    def delete_rebunched_ws(self):
        """
        Delete the rebunched workspace; calling it again after deletion does nothing.
        """
        if self.rebunched_ws is None:
            return
        DeleteWorkspace(Workspace=self.rebunched_ws)
        self.rebunched_ws = None

    def get_ws_as_3d_array(self):
        """
        Return the 3d counts array built from the rebunched workspace.
        """
        return self.ws_3d

    def __len__(self):
        """
        Return number of tubes in the ws
        """
        return self.ws_3d.shape[0]

    def __getitem__(self, idx):
        """
        Return tube idx as an image tensor.

        The 2d tube data is repeated three times along a new trailing axis (float32) and then
        converted to a tensor by the albumentations ToTensorV2 transform.
        """
        tube_data = self.ws_3d[idx, ...]
        tube_data = np.tile(tube_data[:, :, None], 3).astype(np.float32)
        frame = self.trans(image=tube_data)
        return frame['image']
Empty file.
6 changes: 6 additions & 0 deletions diffraction/WISH/bragg-detect/cnn/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-f https://download.pytorch.org/whl/cu118
torch
torchvision

albumentations==1.4.0
tqdm==4.66.2
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import h5py
from os import path
from bragg_detect import detect_bragg_peaks

from bragg_utils import make_3d_array, createPeaksWorkspaceFromIndices

# functions to process data to use as ML input

Expand All @@ -22,26 +22,6 @@ def load_raw_run(runno):
ws = ConvertToPointData(InputWorkspace=str(runno), OutputWorkspace=str(runno))
return ws

def make_3d_array(ws, ntubes=1520, npix_per_tube=128, ntubes_per_bank=152, nmonitors=5):
    """
    Build a 3d counts array (tube, pixel, bin) from the Y data of a WISH MatrixWorkspace.

    The leading monitor spectra are dropped, the pixel order along every tube is flipped
    and the tube order inside every complete bank is reversed.

    :param ws: MatrixWorkspace of WISH data (only extractY() and blocksize() are called)
    :param ntubes: number of tubes in instrument (WISH)
    :param npix_per_tube: number of detector pixels in each tube
    :param ntubes_per_bank: number of tubes per bank
    :param nmonitors: number of monitor spectra (assumed to be the first spectra in ws)
    :return: 3d numpy array of shape ntubes x npix_per_tube x nbins
    """
    detector_counts = ws.extractY()[nmonitors:, :]  # monitors are the first spectra
    nbins = ws.blocksize()
    # tubes x pixels x bins, with the pixel axis flipped along each tube
    cube = np.reshape(detector_counts, (ntubes, npix_per_tube, nbins))[:, ::-1, :]
    # reverse the tube order bank by bank, in place so no second full-size array is allocated
    n_banked_tubes = (ntubes // ntubes_per_bank) * ntubes_per_bank
    for first_tube in range(0, n_banked_tubes, ntubes_per_bank):
        bank = slice(first_tube, first_tube + ntubes_per_bank)
        cube[bank, :, :] = cube[bank, :, :][::-1, :, :]
    return cube

def save_data(data, filepath):
"""
Expand All @@ -68,49 +48,6 @@ def read_data(filepath):
def save_found_peaks(filepath, indices):
    """
    Write the indices of found peaks to a text file as integers.

    :param filepath: path of the text file to write
    :param indices: array of peak indices to save
    """
    np.savetxt(fname=filepath, X=indices, fmt='%d')

# functions to process output of ML peak finding

def findSpectrumIndex(indices, ntubes=1520, npix_per_tube=128, ntubes_per_bank=152, nmonitors=5):
    """
    Map (tube, pixel) positions in the 3d array to workspace spectrum indices.

    :param indices: 2d integer array of found peaks in 3d array (column 0 = tube, column 1 = pixel)
    :param ntubes: number of tubes in instrument (WISH)
    :param npix_per_tube: number of detector pixels in each tube
    :param ntubes_per_bank: number of tubes per bank
    :param nmonitors: number of monitor spectra (assumed first spectra in ws)
    :return: list of spectrum indices of workspace corresponding to indices of found peaks in 3d array
    """
    tube_in_array = indices[:, 0]
    pix_in_array = indices[:, 1]
    # locate the bank of each tube, then reverse the order of tubes within that bank
    first_tube_of_bank = np.floor(tube_in_array / ntubes_per_bank) * ntubes_per_bank
    mirrored_offset = (ntubes_per_bank - 1) - tube_in_array % ntubes_per_bank
    itube = (first_tube_of_bank + mirrored_offset).astype(int)
    # flip the pixel position along the tube
    ipix = (npix_per_tube - 1) - pix_in_array
    # flatten (tube, pixel) row-major and skip over the monitor spectra
    flat_index = np.ravel_multi_index((itube, ipix), dims=(ntubes, npix_per_tube), order='C')
    return (flat_index + nmonitors).tolist()  # tolist so elements are python int, not numpy ints

def createPeaksWorkspaceFromIndices(ws, peak_wsname, indices, data):
    """
    Create a PeaksWorkspace from the indices of peaks found in the 3d array.

    Each WISH spectrum groups several detectors (four), so the peak is attached to a central one.
    Could add peak using avg QLab for peaks at that lambda in all detectors but peak placed in
    nearest detector anyway - for more details see https://github.com/mantidproject/mantid/issues/31944

    :param ws: MatrixWorkspace of WISH data the 3d array was made from
    :param peak_wsname: Output name of peaks workspace created
    :param indices: 2d integer array, one row per peak, columns (tube, pixel, bin) in the 3d array
    :param data: 3d array of data (its first two dimensions give ntubes and npix_per_tube)
    :return: PeaksWorkspace
    """
    spectrum_indices = findSpectrumIndex(indices, *data.shape[0:2])
    peaks = CreatePeaksWorkspace(InstrumentWorkspace=ws, NumberOfPeaks=0,
                                 OutputWorkspace=peak_wsname)
    for ipk, wsindex in enumerate(spectrum_indices):
        itube, ipix, ibin = indices[ipk, 0], indices[ipk, 1], indices[ipk, 2]
        # X value of the spectrum at the peak's bin, handed to AddPeak through its TOF argument
        x_at_peak = ws.readX(wsindex)[ibin]
        # several detectors per spectrum, so use one of the central ones
        det_ids = ws.getSpectrum(wsindex).getDetectorIDs()
        central = (len(det_ids) - 1) // 2
        AddPeak(PeaksWorkspace=peaks, RunWorkspace=ws, TOF=x_at_peak, DetectorID=det_ids[central],
                BinCount=data[itube, ipix, ibin])
    return peaks

if __name__ == "__main__" or "mantidqt.widgets.codeeditor.execution":
# 1. reduce data and save 3d numpy array from ML algorithm (bragg-detect)
Expand Down

0 comments on commit f90bafb

Please sign in to comment.