Skip to content

Commit

Permalink
feat(multi_object_tracker): add object class filtering in tracking pr…
Browse files Browse the repository at this point in the history
…ocess (autowarefoundation#6607)

* feat: object class filter

Signed-off-by: Taekjin LEE <[email protected]>

* fix: set a member private

Signed-off-by: Taekjin LEE <[email protected]>

* fix: last filtered label is not useful, remove

Signed-off-by: Taekjin LEE <[email protected]>

* style(pre-commit): autofix

Signed-off-by: Taekjin LEE <[email protected]>

* fix: multiply gain for new class

Signed-off-by: Taekjin LEE <[email protected]>

* style(pre-commit): autofix

Signed-off-by: Taekjin LEE <[email protected]>

* chore: algorithm explanation

Signed-off-by: Taekjin LEE <[email protected]>

* fix: revise the filtering process flow

Signed-off-by: Taekjin LEE <[email protected]>

---------

Signed-off-by: Taekjin LEE <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
technolojin and pre-commit-ci[bot] authored Mar 15, 2024
1 parent 5a6c82e commit 7ba8016
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class Tracker
{
classification_ = classification;
}
void updateClassification(
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification);

private:
unique_identifier_msgs::msg::UUID uuid_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ bool MultipleVehicleTracker::measure(
big_vehicle_tracker_.measure(object, time, self_transform);
normal_vehicle_tracker_.measure(object, time, self_transform);
if (object_recognition_utils::getHighestProbLabel(object.classification) != Label::UNKNOWN)
setClassification(object.classification);
updateClassification(object.classification);
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ bool PedestrianAndBicycleTracker::measure(
pedestrian_tracker_.measure(object, time, self_transform);
bicycle_tracker_.measure(object, time, self_transform);
if (object_recognition_utils::getHighestProbLabel(object.classification) != Label::UNKNOWN)
setClassification(object.classification);
updateClassification(object.classification);
return true;
}

Expand Down
61 changes: 61 additions & 0 deletions perception/multi_object_tracker/src/tracker/model/tracker_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,67 @@ bool Tracker::updateWithoutMeasurement()
return true;
}

void Tracker::updateClassification(
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification)
{
// classification algorithm:
// 0. Normalize the input classification
// 1-1. Update the matched classification probability with a gain (ratio of 0.05)
// 1-2. If the label is not found, add it to the classification list
// 2. Remove the class with probability < remove_threshold (0.001)
// 3. Normalize tracking classification

// Parameters
// if the remove_threshold is too high (compare to the gain), the classification will be removed
// immediately
const double gain = 0.05;
constexpr double remove_threshold = 0.001;

// Normalization function
auto normalizeProbabilities =
[](std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification) {
double sum = 0.0;
for (const auto & class_ : classification) {
sum += class_.probability;
}
for (auto & class_ : classification) {
class_.probability /= sum;
}
};

// Normalize the input
auto classification_input = classification;
normalizeProbabilities(classification_input);

// Update the matched classification probability with a gain
for (const auto & new_class : classification_input) {
bool found = false;
for (auto & old_class : classification_) {
if (new_class.label == old_class.label) {
old_class.probability += new_class.probability * gain;
found = true;
break;
}
}
// If the label is not found, add it to the classification list
if (!found) {
auto adding_class = new_class;
adding_class.probability *= gain;
classification_.push_back(adding_class);
}
}

// If the probability is less than the threshold, remove the class
classification_.erase(
std::remove_if(
classification_.begin(), classification_.end(),
[remove_threshold](const auto & class_) { return class_.probability < remove_threshold; }),
classification_.end());

// Normalize tracking classification
normalizeProbabilities(classification_);
}

geometry_msgs::msg::PoseWithCovariance Tracker::getPoseWithCovariance(
const rclcpp::Time & time) const
{
Expand Down

0 comments on commit 7ba8016

Please sign in to comment.