From 14db029494171268600914995369e0962f48c29a Mon Sep 17 00:00:00 2001
From: Matthew O'Malley-Nichols <91226873+onichmath@users.noreply.github.com>
Date: Fri, 9 Aug 2024 22:57:52 -0700
Subject: [PATCH] Soft Non-Maximum Suppression (#2400)

* Soft NMS with thresholds
* NMS Test
* Soft nms w/ boxes removed below threshold
* Soft nms test
* No longer removing bounding boxes to fit Soft-NMS focus
* Initialize confidence
* Added comments
* Refactored out updating based on IOU/sigma
* Score_threshold -> confidence_threshold for clarity
* Remove bboxes below confidence threshold
* Softnms basic functionality test
* Softnms confidence decay test
* Softnms confidence threshold test
* Softnms no overlapping bbox test
* Testing confidence after no overlap test
* Single bbox and no bbox tests
* Signify test completion
* Handling result of test functions
* Checking all pairs of bboxes instead of a forward pass
* Equal confidence overlap test
* Clarified tests for implementation
* No longer dropping boxes, just setting to 0.0
* Formatted w/ cargo
---
 candle-transformers/src/object_detection.rs |  69 ++++++
 candle-transformers/tests/nms_tests.rs      | 246 ++++++++++++++++++++
 2 files changed, 315 insertions(+)
 create mode 100644 candle-transformers/tests/nms_tests.rs

diff --git a/candle-transformers/src/object_detection.rs b/candle-transformers/src/object_detection.rs
index ce5793165..e922075fc 100644
--- a/candle-transformers/src/object_detection.rs
+++ b/candle-transformers/src/object_detection.rs
@@ -50,3 +50,72 @@ pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
         bboxes_for_class.truncate(current_index);
     }
 }
+
+/// Applies the Gaussian Soft-NMS penalty to every box that overlaps a higher-ranked box.
+///
+/// `bboxes_for_class` must already be sorted by descending confidence and
+/// `updated_confidences` must hold one entry per box, initialised to the boxes' current
+/// confidences. Penalties compound: a box overlapping several higher-ranked boxes is decayed
+/// once for each of them.
+fn update_confidences<D>(
+    bboxes_for_class: &[Bbox<D>],
+    updated_confidences: &mut [f32],
+    iou_threshold: f32,
+    sigma: f32,
+) {
+    for (current_index, current_bbox) in bboxes_for_class.iter().enumerate() {
+        for (index, bbox) in bboxes_for_class.iter().enumerate().skip(current_index + 1) {
+            let iou_val = iou(current_bbox, bbox);
+            if iou_val > iou_threshold {
+                // Decay calculation from page 4 of: https://arxiv.org/pdf/1704.04503
+                let decay = (-iou_val * iou_val / sigma).exp();
+                // Scale the running value, not the original confidence, so an earlier
+                // penalty is not overwritten by a later one.
+                updated_confidences[index] *= decay;
+            }
+        }
+    }
+}
+
+/// Sorts the bounding boxes of each class by descending confidence and applies Gaussian
+/// soft non-maximum suppression, following <https://arxiv.org/pdf/1704.04503>.
+///
+/// No box is removed: the confidence of a box whose IoU with a higher-ranked box exceeds
+/// `iou_threshold` is decayed, and any confidence that ends up below `confidence_threshold`
+/// is set to `0.0`.
+///
+/// `iou_threshold` defaults to `0.5`, `confidence_threshold` to `0.1` and `sigma` to `0.5`.
+pub fn soft_non_maximum_suppression<D>(
+    bboxes: &mut [Vec<Bbox<D>>],
+    iou_threshold: Option<f32>,
+    confidence_threshold: Option<f32>,
+    sigma: Option<f32>,
+) {
+    let iou_threshold = iou_threshold.unwrap_or(0.5);
+    let confidence_threshold = confidence_threshold.unwrap_or(0.1);
+    let sigma = sigma.unwrap_or(0.5);
+
+    for bboxes_for_class in bboxes.iter_mut() {
+        // Sort boxes by confidence in descending order. `total_cmp` is a total order, so a
+        // NaN confidence cannot make the comparator panic.
+        bboxes_for_class.sort_by(|b1, b2| b2.confidence.total_cmp(&b1.confidence));
+        let mut updated_confidences = bboxes_for_class
+            .iter()
+            .map(|bbox| bbox.confidence)
+            .collect::<Vec<_>>();
+        update_confidences(
+            bboxes_for_class,
+            &mut updated_confidences,
+            iou_threshold,
+            sigma,
+        );
+        // Update confidences, set to 0.0 if below threshold
+        for (bbox, &confidence) in bboxes_for_class.iter_mut().zip(&updated_confidences) {
+            bbox.confidence = if confidence < confidence_threshold {
+                0.0
+            } else {
+                confidence
+            };
+        }
+    }
+}
diff --git a/candle-transformers/tests/nms_tests.rs b/candle-transformers/tests/nms_tests.rs
new file mode 100644
index 000000000..d70f6fdf3
--- /dev/null
+++ b/candle-transformers/tests/nms_tests.rs
@@ -0,0 +1,246 @@
+use candle::Result;
+use candle_transformers::object_detection::{
+    non_maximum_suppression, soft_non_maximum_suppression, Bbox,
+};
+
+#[test]
+fn nms_basic() -> Result<()> {
+    // Boxes based upon https://thepythoncode.com/article/non-maximum-suppression-using-opencv-in-python
+    let mut bboxes = vec![vec![
+        Bbox {
+            xmin: 245.0,
+            ymin: 305.0,
+            xmax: 575.0,
+            ymax: 490.0,
+            confidence: 0.9,
+            data: (),
+        }, // Box 1
+        Bbox {
+            xmin: 235.0,
+            ymin: 300.0,
+            xmax: 485.0,
+            ymax: 515.0,
+            confidence: 0.8,
+            data: (),
+        }, // Box 2
+        Bbox {
+            xmin: 305.0,
+            ymin: 270.0,
+            xmax: 540.0,
+            ymax: 500.0,
+            confidence: 0.6,
+            data: (),
+        }, // Box 3
+    ]];
+
+    non_maximum_suppression(&mut bboxes, 0.5);
+    let bboxes = bboxes.into_iter().next().unwrap();
+    assert_eq!(bboxes.len(), 1);
+    assert_eq!(bboxes[0].confidence, 0.9);
+
+    Ok(())
+}
+
+#[test]
+fn softnms_basic_functionality() -> Result<()> {
+    let mut bboxes = vec![vec![
+        Bbox {
+            xmin: 0.0,
+            ymin: 0.0,
+            xmax: 1.0,
+            ymax: 1.0,
+            confidence: 0.5,
+            data: (),
+        },
+        Bbox {
+            xmin: 0.1,
+            ymin: 0.1,
+            xmax: 1.1,
+            ymax: 1.1,
+            confidence: 0.9,
+            data: (),
+        },
+        Bbox {
+            xmin: 0.2,
+            ymin: 0.2,
+            xmax: 1.2,
+            ymax: 1.2,
+            confidence: 0.6,
+            data: (),
+        },
+    ]];
+
+    soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+    // Boxes are sorted by descending confidence; both boxes after the top one must decay
+    assert!(bboxes[0][0].confidence == 0.9);
+    assert!(bboxes[0][1].confidence < 0.6);
+    assert!(bboxes[0][2].confidence < 0.5);
+    Ok(())
+}
+
+#[test]
+fn softnms_confidence_decay() -> Result<()> {
+    let mut bboxes = vec![vec![
+        Bbox {
+            xmin: 0.0,
+            ymin: 0.0,
+            xmax: 1.0,
+            ymax: 1.0,
+            confidence: 0.9,
+            data: (),
+        }, // Reference box
+        Bbox {
+            xmin: 0.1,
+            ymin: 0.1,
+            xmax: 1.1,
+            ymax: 1.1,
+            confidence: 0.8,
+            data: (),
+        }, // Overlapping box
+    ]];
+
+    soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+    // Check that confidence of the overlapping box is decayed
+    assert!(bboxes[0][0].confidence == 0.9);
+    assert!(bboxes[0][1].confidence < 0.8);
+    Ok(())
+}
+
+#[test]
+fn softnms_confidence_threshold() -> Result<()> {
+    let mut bboxes = vec![vec![
+        Bbox {
+            xmin: 0.0,
+            ymin: 0.0,
+            xmax: 1.0,
+            ymax: 1.0,
+            confidence: 0.9,
+            data: (),
+        },
+        Bbox {
+            xmin: 0.1,
+            ymin: 0.1,
+            xmax: 1.1,
+            ymax: 1.1,
+            confidence: 0.05,
+            data: (),
+        },
+    ]];
+
+    soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+    // Box whose confidence falls below the threshold is kept, but its confidence is zeroed
+    assert_eq!(bboxes[0].len(), 2);
+    assert_eq!(bboxes[0][0].confidence, 0.9);
+    assert_eq!(bboxes[0][1].confidence, 0.00);
+    Ok(())
+}
+
+#[test]
+fn softnms_no_overlap() -> Result<()> {
+    let mut bboxes = vec![vec![
+        Bbox {
+            xmin: 0.0,
+            ymin: 0.0,
+            xmax: 1.0,
+            ymax: 1.0,
+            confidence: 0.9,
+            data: (),
+        },
+        Bbox {
+            xmin: 2.0,
+            ymin: 2.0,
+            xmax: 3.0,
+            ymax: 3.0,
+            confidence: 0.8,
+            data: (),
+        },
+    ]];
+
+    soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+    // Both boxes should remain as they do not significantly overlap
+    assert_eq!(bboxes[0].len(), 2);
+    assert_eq!(bboxes[0][0].confidence, 0.9);
+    assert_eq!(bboxes[0][1].confidence, 0.8);
+    Ok(())
+}
+#[test]
+fn softnms_no_bbox() -> Result<()> {
+    let mut bboxes: Vec<Vec<Bbox<()>>> = vec![];
+    soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+    assert!(bboxes.is_empty());
+    Ok(())
+}
+
+#[test]
+fn softnms_single_bbox() -> Result<()> {
+    let mut bboxes = vec![vec![Bbox {
+        xmin: 0.0,
+        ymin: 0.0,
+        xmax: 1.0,
+        ymax: 1.0,
+        confidence: 0.9,
+        data: (),
+    }]];
+    soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+    assert_eq!(bboxes[0].len(), 1);
+    Ok(())
+}
+
+#[test]
+fn softnms_equal_confidence_overlap() -> Result<()> {
+    let mut bboxes = vec![vec![
+        Bbox {
+            xmin: 0.0,
+            ymin: 0.0,
+            xmax: 1.0,
+            ymax: 1.0,
+            confidence: 0.5,
+            data: (),
+        },
+        Bbox {
+            xmin: 0.1,
+            ymin: 0.1,
+            xmax: 1.1,
+            ymax: 1.1,
+            confidence: 0.5,
+            data: (),
+        },
+    ]];
+
+    soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
+
+    // The sort is stable, so the first box stays the reference box and only the second decays
+    // (decaying both boxes would require a change to the implementation)
+    assert_eq!(bboxes[0].len(), 2);
+    assert!(bboxes[0][0].confidence == 0.5);
+    assert!(bboxes[0][1].confidence < 0.5);
+    Ok(())
+}
+
+#[test]
+fn softnms_decay_compounds() -> Result<()> {
+    // Three identical boxes: every pair has an IoU of 1.0
+    let bbox = |confidence: f32| Bbox {
+        xmin: 0.0,
+        ymin: 0.0,
+        xmax: 1.0,
+        ymax: 1.0,
+        confidence,
+        data: (),
+    };
+    let mut bboxes = vec![vec![bbox(0.9), bbox(0.8), bbox(0.7)]];
+
+    soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.0), Some(0.5));
+
+    // With sigma = 0.5 each overlap multiplies the confidence by exp(-2); the last box
+    // overlaps both higher-ranked boxes, so it must be decayed twice
+    let decay = (-2.0f32).exp();
+    assert_eq!(bboxes[0][0].confidence, 0.9);
+    assert!((bboxes[0][1].confidence - 0.8 * decay).abs() < 1e-6);
+    assert!((bboxes[0][2].confidence - 0.7 * decay * decay).abs() < 1e-6);
+    Ok(())
+}