Skip to content

Commit

Permalink
feat: object class filter
Browse files Browse the repository at this point in the history
  • Loading branch information
technolojin committed Mar 13, 2024
1 parent fcebec6 commit 3fd3469
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class Tracker
{
classification_ = classification;
}
void updateClassification(
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification);

private:
unique_identifier_msgs::msg::UUID uuid_;
Expand All @@ -51,6 +53,9 @@ class Tracker
int total_measurement_count_;
rclcpp::Time last_update_with_measurement_time_;

public:
autoware_auto_perception_msgs::msg::ObjectClassification last_filtered_class_;

public:
Tracker(
const rclcpp::Time & time,
Expand All @@ -68,6 +73,7 @@ class Tracker
{
return object_recognition_utils::getHighestProbLabel(classification_);
}
std::uint8_t getFilteredLabel() const { return last_filtered_class_.label; }
int getNoMeasurementCount() const { return no_measurement_count_; }
int getTotalNoMeasurementCount() const { return total_no_measurement_count_; }
int getTotalMeasurementCount() const { return total_measurement_count_; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ bool PedestrianAndBicycleTracker::measure(
pedestrian_tracker_.measure(object, time, self_transform);
bicycle_tracker_.measure(object, time, self_transform);
if (object_recognition_utils::getHighestProbLabel(object.classification) != Label::UNKNOWN)
setClassification(object.classification);
// setClassification(object.classification);
updateClassification(object.classification);
return true;
}

Expand Down
80 changes: 80 additions & 0 deletions perception/multi_object_tracker/src/tracker/model/tracker_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Tracker::Tracker(
std::mt19937 gen(std::random_device{}());
std::independent_bits_engine<std::mt19937, 8, uint8_t> bit_eng(gen);
std::generate(uuid_.uuid.begin(), uuid_.uuid.end(), bit_eng);

// initialize last_filtered_class_
last_filtered_class_ = object_recognition_utils::getHighestProbClassification(classification_);
}

bool Tracker::updateWithMeasurement(
Expand All @@ -54,6 +57,83 @@ bool Tracker::updateWithoutMeasurement()
return true;
}

void Tracker::updateClassification(
const std::vector<autoware_auto_perception_msgs::msg::ObjectClassification> & classification)
{
// Update classification
// 1. Match classification label
// 2. Update the matched classification probability with a gain
// 3. If the label is not found, add it to the classification list
// 4. If the old class probability is not found, decay the probability
// 5. Normalize the probability

const double gain = 0.05;
const double gain_inv = 1.0 - gain;
const double decay = gain_inv;

for (const auto & new_class : classification) {
bool found = false;
for (auto & old_class : classification_) {
// Update the matched classification probability with a gain
if (new_class.label == old_class.label) {
old_class.probability = old_class.probability * gain_inv + new_class.probability * gain;
found = true;
break;
}
}
// If the label is not found, add it to the classification list
if (!found) {
classification_.push_back(new_class);
}
}
// If the old class probability is not found, decay the probability
for (auto & old_class : classification_) {
bool found = false;
for (const auto & new_class : classification) {
if (new_class.label == old_class.label) {
found = true;
break;
}
}
if (!found) {
old_class.probability *= decay;
}
}

// Normalize
double sum = 0.0;
for (const auto & class_ : classification_) {
sum += class_.probability;
}
for (auto & class_ : classification_) {
class_.probability /= sum;
}

// If the probability is too small, remove the class
classification_.erase(
std::remove_if(
classification_.begin(), classification_.end(),
[](const auto & class_) { return class_.probability < 0.001; }),
classification_.end());

// Set the last filtered class
// if the highest probability class is not overcome a certain hysteresis, the last
// filtered class stays the same

for (const auto & class_ : classification_) {
if (class_.label == last_filtered_class_.label) {
last_filtered_class_.probability = class_.probability;
break;
}
}
const double hysteresis = 0.1;
autoware_auto_perception_msgs::msg::ObjectClassification const new_classification =
object_recognition_utils::getHighestProbClassification(classification_);
if (new_classification.probability > last_filtered_class_.probability + hysteresis) {
last_filtered_class_ = new_classification;
}
}

geometry_msgs::msg::PoseWithCovariance Tracker::getPoseWithCovariance(
const rclcpp::Time & time) const
{
Expand Down

0 comments on commit 3fd3469

Please sign in to comment.