diff --git a/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py b/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py
index 61d769f..a10992b 100644
--- a/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py
+++ b/diffraction/WISH/bragg-detect/cnn/BraggDetectCNN.py
@@ -8,6 +8,14 @@ from tqdm import tqdm
 from Diffraction.single_crystal.base_sx import BaseSX
 import time
+from enum import Enum
+from sklearn.cluster import HDBSCAN
+from sklearn.metrics import silhouette_score
+
+class Clustering(Enum):
+    QLab = 1
+    HDBSCAN = 2
+
 
 class BraggDetectCNN:
     """
@@ -19,7 +27,7 @@ class BraggDetectCNN:
     # 2) Create a peaks workspace containing bragg peaks detected with a confidence greater than conf_threshold
     cnn_bragg_peaks_detector = BraggDetectCNN(model_weights_path=cnn_weights_path, batch_size=64, workers=0, iou_threshold=0.001)
-    cnn_bragg_peaks_detector.find_bragg_peaks(workspace="WISH00042730", conf_threshold=0.0, q_tol=0.05)
+    cnn_bragg_peaks_detector.find_bragg_peaks(workspace="WISH00042730", conf_threshold=0.0, clustering="QLab", q_tol=0.05)
     """
 
     def __init__(self, model_weights_path, batch_size=64, workers=0, iou_threshold=0.001):
@@ -37,28 +45,83 @@ def __init__(self, model_weights_path, batch_size=64, workers=0, iou_threshold=0
         self.iou_threshold = iou_threshold
 
-    def find_bragg_peaks(self, workspace, output_ws_name="CNN_Peaks", conf_threshold=0.0, q_tol=0.05):
+    def find_bragg_peaks(self, workspace, output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering=Clustering.QLab.name, **kwargs):
         """
         Find bragg peaks using the pre trained FasterRCNN model and create a peaks workspace
         :param workspace: Workspace name or the object of Workspace from WISH, ex: "WISH0042730"
         :param output_ws_name: Name of the peaks workspace
         :param conf_threshold: Confidence threshold to filter peaks inferred from RCNN
-        :param q_tol: qlab tolerance to remove duplicate peaks
+        :param clustering: name of the clustering method (QLab or HDBSCAN). Default is QLab
+        :param kwargs: variable keyword params for clustering methods
         """
         start_time = time.time()
         data_set, predicted_indices = self._do_cnn_inferencing(workspace)
+
         filtered_indices = predicted_indices[predicted_indices[:, -1] > conf_threshold]
-        filtered_indices_rounded = np.round(filtered_indices[:, :-1]).astype(int)
-        peaksws = createPeaksWorkspaceFromIndices(data_set.get_workspace(), output_ws_name, filtered_indices_rounded, data_set.get_ws_as_3d_array())
+
+        #Do Clustering
+        print(f"Starting peak clustering with {clustering} method..")
+        clustered_peaks = self._do_peak_clustering(filtered_indices, clustering, **kwargs)
+        cluster_indices_rounded = np.round(clustered_peaks[:, :3]).astype(int)
+        peaksws = createPeaksWorkspaceFromIndices(data_set.get_workspace(), output_ws_name, cluster_indices_rounded, data_set.get_ws_as_3d_array())
         for ipk, pk in enumerate(peaksws):
-            pk.setIntensity(filtered_indices[ipk, -1])
+            pk.setIntensity(clustered_peaks[ipk, -1])
+
+        if clustering == Clustering.QLab.name:
+            #Filter peaks by qlab
+            clustering_params = {"q_tol": 0.05}
+            clustering_params.update(kwargs)
+            BaseSX.remove_duplicate_peaks_by_qlab(peaksws, **clustering_params)
+
+        print(f"Number of peaks after clustering = {len(peaksws)}")
 
-        #Filter duplicates by qlab
-        BaseSX.remove_duplicate_peaks_by_qlab(peaksws, q_tol)
         data_set.delete_rebunched_ws()
-        print(f"Bragg peaks finding from FasterRCNN model is completed in {time.time()-start_time} seconds!")
+        print(f"Bragg peaks finding from FasterRCNN model is completed in {time.time()-start_time:.2f} seconds!")
+
+
+    def _do_peak_clustering(self, detected_peaks, clustering, **kwargs):
+        print(f"Number of peaks before clustering = {len(detected_peaks)}")
+        if clustering == Clustering.HDBSCAN.name:
+            return self._do_hdbscan_clustering(detected_peaks, **kwargs)
+        else:
+            return detected_peaks
+
+    def _do_hdbscan_clustering(self, peakdata, keep_ignored_labels=True, **kwargs):
+        """
+        Do HDBSCAN clustering over the inferred peak coordinates
+        :param peakdata: np array containing the inferred peak coordinates
+        :param keep_ignored_labels: whether to include the unclustered peaks in the final result.
+            Default is True; can be set to False by passing "keep_ignored_labels": False in kwargs
+        :param kwargs: variable keyword params to be passed to the HDBSCAN algorithm
+            https://scikit-learn.org/1.5/modules/generated/sklearn.cluster.HDBSCAN.html
+        """
+        peak_indices = np.delete(peakdata, [3,4], axis=1)
+        if ("keep_ignored_labels" in kwargs):
+            keep_ignored_labels = kwargs.pop("keep_ignored_labels")
+
+        hdbscan_params = {"min_cluster_size": 2,
+                          "min_samples": 2,
+                          "store_centers" : "medoid",
+                          "algorithm": "auto",
+                          "cluster_selection_method": "eom",
+                          "metric": "euclidean"
+                          }
+        hdbscan_params.update(kwargs)
+        hdbscan = HDBSCAN(**hdbscan_params)
+        hdbscan.fit(peak_indices)
+        print(f"Silhouette score of the clusters={silhouette_score(peak_indices, hdbscan.labels_)}")
+
+        if keep_ignored_labels:
+            selected_peak_indices = np.concatenate((hdbscan.medoids_, peak_indices[np.where(hdbscan.labels_==-1)]), axis=0)
+        else:
+            selected_peak_indices = hdbscan.medoids_
+        confidence = []
+        for peak in selected_peak_indices:
+            confidence.append(peakdata[np.where((peak_indices == peak).all(axis=1))[0].item(), -1])
+        return np.column_stack((selected_peak_indices, confidence))
+
+
     def _do_cnn_inferencing(self, workspace):
         data_set = WISHWorkspaceDataSet(workspace)
         data_loader = tc.utils.data.DataLoader(data_set, batch_size=self.batch_size, shuffle=False, num_workers=self.workers)
@@ -71,9 +134,14 @@ def _do_cnn_inferencing(self, workspace):
                 prediction = self.model([img.to(self.device)])[0]
                 nms_prediction = self._apply_nms(prediction, self.iou_threshold)
                 for box, score in zip(nms_prediction['boxes'], nms_prediction['scores']):
+                    box = box.cpu().numpy().astype(int)
                     tof = (box[0]+box[2])/2
                     tube_res = (box[1]+box[3])/2
-                    predicted_indices_with_score.append([tube_idx, tube_res.cpu(), tof.cpu(), score.cpu()])
+
+                    boxsum = np.sum(img[0, box[1]:box[3], box[0]:box[2]].numpy())
+
+                    predicted_indices_with_score.append([tube_idx, tube_res, tof, boxsum, score.cpu()])
+
         return data_set, np.array(predicted_indices_with_score)
 
@@ -98,7 +166,7 @@ def _select_device(self):
 
     def _load_cnn_model_from_weights(self, weights_path):
         model = self._get_fasterrcnn_resnet50_fpn(num_classes=2)
-        model.load_state_dict(tc.load(weights_path, map_location=self.device))
+        model.load_state_dict(tc.load(weights_path, map_location=self.device, weights_only=True))
         return model.to(self.device)
 
diff --git a/diffraction/WISH/bragg-detect/cnn/README.md b/diffraction/WISH/bragg-detect/cnn/README.md
index 461e095..f374367 100644
--- a/diffraction/WISH/bragg-detect/cnn/README.md
+++ b/diffraction/WISH/bragg-detect/cnn/README.md
@@ -1,17 +1,29 @@
 Bragg Peaks detection using a pre-trained Faster RCNN deep neural network
 ================
-Inorder to use the pre-trained Faster RCNN model inside mantid using an IDAaaS instance, below steps are required.
+In order to run the pre-trained Faster RCNN model via Mantid inside an IDAaaS instance, the steps below are required.
-* Launch an IDAaaS instance with GPUs from WISH > Wish Single Crystal GPU Advanced
-* Launch Mantid workbench nightly from Applications->Software->Mantid->Mantid Workbench Nightly
+* Launch an IDAaaS instance with GPUs selected from WISH > Wish Single Crystal GPU Advanced
+* From IDAaaS, launch Mantid workbench nightly from Applications->Software->Mantid->Mantid Workbench Nightly
 * Download `scriptrepository\diffraction\WISH` directory from mantid's script repository as instructed here https://docs.mantidproject.org/nightly/workbench/scriptrepository.html
 * Check whether `\diffraction\WISH` path is listed under `Python Script Directories` tab from `File->Manage User Directories` of Mantid workbench.
-* Below is an example code snippet to test the code. It will create a peaks workspace with the inferred peaks from the cnn and will do a peak filtering using the q_tol provided using `BaseSX.remove_duplicate_peaks_by_qlab`.
+* Below is an example code snippet to use the pre-trained model for Bragg peak detection. It will create a peaks workspace with the inferred peaks from the model. The valid values for the `clustering` argument are `QLab` or `HDBSCAN`. For the `QLab` method, the default value of `q_tol=0.05` will be used for the `BaseSX.remove_duplicate_peaks_by_qlab` method.
 ```python
 from cnn.BraggDetectCNN import BraggDetectCNN
 model_weights = r'/mnt/ceph/auxiliary/wish/BraggDetect_FasterRCNN_Resnet50_Weights_v1.pt'
 cnn_peaks_detector = BraggDetectCNN(model_weights_path=model_weights, batch_size=64)
-cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, q_tol=0.05)
+cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="QLab")
 ```
 * If the above import is not working, check whether the `\diffraction\WISH` path is listed under `Python Script Directories` tab from `File->Manage User Directories`.
+* Depending on the selected `clustering` method, custom parameters can be provided via `kwargs` as shown below.
+```python
+kwargs={"q_tol": 0.1}
+cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="QLab", **kwargs)
+
+# or
+
+kwargs={"cluster_selection_method": "leaf", "algorithm": "brute", "keep_ignored_labels": False}
+cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, clustering="HDBSCAN", **kwargs)
+```
+* The documentation for HDBSCAN can be found here: https://scikit-learn.org/1.5/modules/generated/sklearn.cluster.HDBSCAN.html
+* The documentation for `BaseSX.remove_duplicate_peaks_by_qlab` can be found here: https://docs.mantidproject.org/nightly/techniques/ISIS_SingleCrystalDiffraction_Workflow.html
diff --git a/diffraction/WISH/bragg-detect/cnn/requirements.txt b/diffraction/WISH/bragg-detect/cnn/requirements.txt
index 7a6670b..e382e8a 100644
--- a/diffraction/WISH/bragg-detect/cnn/requirements.txt
+++ b/diffraction/WISH/bragg-detect/cnn/requirements.txt
@@ -1,6 +1,4 @@
--f https://download.pytorch.org/whl/cu118
-torch
-torchvision
-
+torch==2.5.1
+torchvision==0.20.1
 albumentations==1.4.0
 tqdm==4.66.3
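
For context on the `_do_hdbscan_clustering` change above, the snippet below is a minimal standalone sketch (not part of the patch) of the medoid-based de-duplication it performs. The synthetic jittered peak coordinates, random seed, and printed summary are illustrative assumptions only; the HDBSCAN parameters mirror the defaults added in the patch (`min_cluster_size=2`, `min_samples=2`, `store_centers="medoid"`), and unclustered points (label `-1`) are kept, matching the `keep_ignored_labels=True` behaviour.
```python
import numpy as np
from sklearn.cluster import HDBSCAN

rng = np.random.default_rng(0)
# 5 "true" peak positions in (tube, tube_res, tof) index space,
# each detected 3 times with small jitter -> 15 candidate detections
true_peaks = rng.uniform(0, 100, size=(5, 3))
candidates = np.vstack([true_peaks + rng.normal(0, 0.2, size=(5, 3)) for _ in range(3)])

# Same clustering defaults as the patch: small clusters allowed, medoid centres stored
clusterer = HDBSCAN(min_cluster_size=2, min_samples=2, store_centers="medoid")
clusterer.fit(candidates)

# Keep one medoid per cluster, plus any unclustered candidates (label == -1)
deduped = np.concatenate((clusterer.medoids_, candidates[clusterer.labels_ == -1]), axis=0)
print(f"{len(candidates)} candidate peaks -> {len(deduped)} peaks after clustering")
```
Returning the medoid (an actual detected point) rather than a centroid keeps the de-duplicated coordinates on real detector indices, which is why `store_centers="medoid"` is used in the patch.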