From 896237b995a44ea9151723c141aac5f1ac09d7dc Mon Sep 17 00:00:00 2001 From: Advitya Gemawat Date: Mon, 2 Oct 2023 11:42:00 -0400 Subject: [PATCH] Sorting support for Object Detection labels (#2376) * sort support ckpt * build fix * python lint fixes * mock data fix --- .../__mock_data__/fridgeObjectDetection.ts | 4 ++-- libs/core-ui/src/lib/DatasetCohort.ts | 2 +- .../Context/buildModelAssessmentContext.ts | 2 +- .../rai_vision_insights.py | 19 +++++++++++++------ 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/apps/dashboard/src/model-assessment-vision/__mock_data__/fridgeObjectDetection.ts b/apps/dashboard/src/model-assessment-vision/__mock_data__/fridgeObjectDetection.ts index 6266a3f4f8..ddd2dcc5bf 100644 --- a/apps/dashboard/src/model-assessment-vision/__mock_data__/fridgeObjectDetection.ts +++ b/apps/dashboard/src/model-assessment-vision/__mock_data__/fridgeObjectDetection.ts @@ -102,12 +102,12 @@ export const fridgeObjectDetection: IDataset = { object_detection_labels: [ { aggregate: "2 correct, 0 incorrect", - correct: "1 milk_bottle, 1 can", + correct: "1 can, 1 milk_bottle", incorrect: "(none)" }, { aggregate: "2 correct, 0 incorrect", - correct: "1 milk_bottle, 1 can", + correct: "1 can, 1 milk_bottle", incorrect: "(none)" }, { diff --git a/libs/core-ui/src/lib/DatasetCohort.ts b/libs/core-ui/src/lib/DatasetCohort.ts index 39d051396f..74f92049ed 100644 --- a/libs/core-ui/src/lib/DatasetCohort.ts +++ b/libs/core-ui/src/lib/DatasetCohort.ts @@ -141,7 +141,7 @@ export class DatasetCohort { dataDict[index][DatasetCohortColumns.PredictedY] = val; } }); - this.dataset.objectDetectionLabels?.forEach((val, index) => { + this.dataset.object_detection_labels?.forEach((val, index) => { dataDict[index][DatasetCohortColumns.ObjectDetectionIncorrect] = val.incorrect; dataDict[index][DatasetCohortColumns.ObjectDetectionCorrect] = diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts index feee0bc44a..a166e68a45 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts @@ -60,7 +60,7 @@ export function buildInitialModelAssessmentContext( featureMetaData: props.dataset.feature_metadata, localExplanations, metadata: modelMetadata, - objectDetectionLabels: props.dataset.objectDetectionLabels, + objectDetectionLabels: props.dataset.object_detection_labels, predictedProbabilities: props.dataset.probability_y, predictedY: props.dataset.predicted_y, targetColumn: props.dataset.target_column, diff --git a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py index 46e4ffce82..72b5814361 100644 --- a/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py +++ b/responsibleai_vision/responsibleai_vision/rai_vision_insights/rai_vision_insights.py @@ -738,19 +738,26 @@ def _generate_od_error_labels(self, true_y, pred_y, class_names): else: image_labels[_INCORRECT][object_label] += 1 - image_labels[_INCORRECT][object_label] += \ - np.count_nonzero( - error_matrix[label_idx] == - ErrorLabelType.DUPLICATE_DETECTION) + duplicate_detections = np.count_nonzero( + error_matrix[label_idx] == + ErrorLabelType.DUPLICATE_DETECTION) + if duplicate_detections > 0: + image_labels[_INCORRECT][object_label] += \ + duplicate_detections + + correct_labels = sorted(image_labels[_CORRECT].items(), + key=lambda x: class_names.index(x[0])) + incorrect_labels = sorted(image_labels[_INCORRECT].items(), + key=lambda x: class_names.index(x[0])) rendered_labels[_CORRECT] = ', '.join( f'{value} {key}' for key, value in - image_labels[_CORRECT].items() if value > 0) + correct_labels) if len(rendered_labels[_CORRECT]) == 0: rendered_labels[_CORRECT] = _NOLABEL rendered_labels[_INCORRECT] = ', '.join( f'{value} {key}' for key, value in - image_labels[_INCORRECT].items() if value > 0) + incorrect_labels) if len(rendered_labels[_INCORRECT]) == 0: rendered_labels[_INCORRECT] = _NOLABEL rendered_labels[_AGGREGATE_LABEL] = \