Skip to content

Commit

Permalink
Soft Non-Maximum Suppression (#2400)
Browse files Browse the repository at this point in the history
* Soft NMS with thresholds

* NMS Test

* Soft nms w/ boxes removed below threshold

* Soft nms test

* No longer removing bounding boxes to fit Soft-NMS focus

* Initialize confidence

* Added comments

* Refactored out updating based on IOU/sigma

* Score_threshold -> confidence_threshold for clarity

* Remove bboxes below confidence threshold

* Softnms basic functionality test

* Softnms confidence decay test

* Softnms confidence threshold test

* Softnms no overlapping bbox test

* Testing confidence after no overlap test

* Single bbox and no bbox tests

* Signify test completion

* Handling result of test functions

* Checking all pairs of bboxes instead of a forward pass

* Equal confidence overlap test

* Clarified tests for implementation

* No longer dropping boxes, just setting to 0.0

* Formatted w/ cargo
  • Loading branch information
onichmath authored Aug 10, 2024
1 parent 6e6c1c9 commit 14db029
Show file tree
Hide file tree
Showing 2 changed files with 280 additions and 0 deletions.
58 changes: 58 additions & 0 deletions candle-transformers/src/object_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,61 @@ pub fn non_maximum_suppression<D>(bboxes: &mut [Vec<Bbox<D>>], threshold: f32) {
bboxes_for_class.truncate(current_index);
}
}

// Updates confidences starting at highest and comparing subsequent boxes.
fn update_confidences<D>(
bboxes_for_class: &[Bbox<D>],
updated_confidences: &mut [f32],
iou_threshold: f32,
sigma: f32,
) {
let len = bboxes_for_class.len();
for current_index in 0..len {
let current_bbox = &bboxes_for_class[current_index];
for index in (current_index + 1)..len {
let iou_val = iou(current_bbox, &bboxes_for_class[index]);
if iou_val > iou_threshold {
// Decay calculation from page 4 of: https://arxiv.org/pdf/1704.04503
let decay = (-iou_val * iou_val / sigma).exp();
let updated_confidence = bboxes_for_class[index].confidence * decay;
updated_confidences[index] = updated_confidence;
}
}
}
}

// Sorts the bounding boxes by confidence and applies soft non-maximum suppression.
// This function is based on the algorithm described in https://arxiv.org/pdf/1704.04503
pub fn soft_non_maximum_suppression<D>(
bboxes: &mut [Vec<Bbox<D>>],
iou_threshold: Option<f32>,
confidence_threshold: Option<f32>,
sigma: Option<f32>,
) {
let iou_threshold = iou_threshold.unwrap_or(0.5);
let confidence_threshold = confidence_threshold.unwrap_or(0.1);
let sigma = sigma.unwrap_or(0.5);

for bboxes_for_class in bboxes.iter_mut() {
// Sort boxes by confidence in descending order
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut updated_confidences = bboxes_for_class
.iter()
.map(|bbox| bbox.confidence)
.collect::<Vec<_>>();
update_confidences(
bboxes_for_class,
&mut updated_confidences,
iou_threshold,
sigma,
);
// Update confidences, set to 0.0 if below threshold
for (i, &confidence) in updated_confidences.iter().enumerate() {
bboxes_for_class[i].confidence = if confidence < confidence_threshold {
0.0
} else {
confidence
};
}
}
}
222 changes: 222 additions & 0 deletions candle-transformers/tests/nms_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
use candle::Result;
use candle_transformers::object_detection::{
non_maximum_suppression, soft_non_maximum_suppression, Bbox,
};

#[test]
fn nms_basic() -> Result<()> {
// Boxes based upon https://thepythoncode.com/article/non-maximum-suppression-using-opencv-in-python
let mut bboxes = vec![vec![
Bbox {
xmin: 245.0,
ymin: 305.0,
xmax: 575.0,
ymax: 490.0,
confidence: 0.9,
data: (),
}, // Box 1
Bbox {
xmin: 235.0,
ymin: 300.0,
xmax: 485.0,
ymax: 515.0,
confidence: 0.8,
data: (),
}, // Box 2
Bbox {
xmin: 305.0,
ymin: 270.0,
xmax: 540.0,
ymax: 500.0,
confidence: 0.6,
data: (),
}, // Box 3
]];

non_maximum_suppression(&mut bboxes, 0.5);
let bboxes = bboxes.into_iter().next().unwrap();
assert_eq!(bboxes.len(), 1);
assert_eq!(bboxes[0].confidence, 0.9);

Ok(())
}

#[test]
fn softnms_basic_functionality() -> Result<()> {
let mut bboxes = vec![vec![
Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.5,
data: (),
},
Bbox {
xmin: 0.1,
ymin: 0.1,
xmax: 1.1,
ymax: 1.1,
confidence: 0.9,
data: (),
},
Bbox {
xmin: 0.2,
ymin: 0.2,
xmax: 1.2,
ymax: 1.2,
confidence: 0.6,
data: (),
},
]];

soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));

// Should decay boxes following highest confidence box
assert!(bboxes[0][0].confidence == 0.9);
assert!(bboxes[0][1].confidence < 0.5);
assert!(bboxes[0][2].confidence < 0.6);
Ok(())
}

#[test]
fn softnms_confidence_decay() -> Result<()> {
let mut bboxes = vec![vec![
Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.9,
data: (),
}, // Reference box
Bbox {
xmin: 0.1,
ymin: 0.1,
xmax: 1.1,
ymax: 1.1,
confidence: 0.8,
data: (),
}, // Overlapping box
]];

soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));

// Check that confidence of the overlapping box is decayed
assert!(bboxes[0][0].confidence == 0.9);
assert!(bboxes[0][1].confidence < 0.8);
Ok(())
}

#[test]
fn softnms_confidence_threshold() -> Result<()> {
let mut bboxes = vec![vec![
Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.9,
data: (),
},
Bbox {
xmin: 0.1,
ymin: 0.1,
xmax: 1.1,
ymax: 1.1,
confidence: 0.05,
data: (),
},
]];

soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));

// Box with confidence below the threshold should be removed
assert_eq!(bboxes[0].len(), 2);
assert_eq!(bboxes[0][0].confidence, 0.9);
assert_eq!(bboxes[0][1].confidence, 0.00);
Ok(())
}

#[test]
fn softnms_no_overlap() -> Result<()> {
let mut bboxes = vec![vec![
Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.9,
data: (),
},
Bbox {
xmin: 2.0,
ymin: 2.0,
xmax: 3.0,
ymax: 3.0,
confidence: 0.8,
data: (),
},
]];

soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));

// Both boxes should remain as they do not significantly overlap
assert_eq!(bboxes[0].len(), 2);
assert_eq!(bboxes[0][0].confidence, 0.9);
assert_eq!(bboxes[0][1].confidence, 0.8);
Ok(())
}
#[test]
fn softnms_no_bbox() -> Result<()> {
let mut bboxes: Vec<Vec<Bbox<()>>> = vec![];
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
assert!(bboxes.is_empty());
Ok(())
}

#[test]
fn softnms_single_bbox() -> Result<()> {
let mut bboxes = vec![vec![Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.9,
data: (),
}]];
soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));
assert_eq!(bboxes[0].len(), 1);
Ok(())
}

#[test]
fn softnms_equal_confidence_overlap() -> Result<()> {
let mut bboxes = vec![vec![
Bbox {
xmin: 0.0,
ymin: 0.0,
xmax: 1.0,
ymax: 1.0,
confidence: 0.5,
data: (),
},
Bbox {
xmin: 0.1,
ymin: 0.1,
xmax: 1.1,
ymax: 1.1,
confidence: 0.5,
data: (),
},
]];

soft_non_maximum_suppression(&mut bboxes, Some(0.5), Some(0.1), Some(0.5));

// First box will be reference box, second box should be decayed
// Implementation must change to have both be decayed
assert_eq!(bboxes[0].len(), 2);
assert!(bboxes[0][0].confidence == 0.5);
assert!(bboxes[0][1].confidence < 0.5);
Ok(())
}

0 comments on commit 14db029

Please sign in to comment.