Skip to content

Commit

Permalink
Sorting support for Object Detection labels (#2376)
Browse files Browse the repository at this point in the history
* sort support ckpt

* build fix

* python lint fixes

* mock data fix
  • Loading branch information
Advitya17 authored Oct 2, 2023
1 parent 7206fe1 commit 896237b
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ export const fridgeObjectDetection: IDataset = {
object_detection_labels: [
{
aggregate: "2 correct, 0 incorrect",
correct: "1 milk_bottle, 1 can",
correct: "1 can, 1 milk_bottle",
incorrect: "(none)"
},
{
aggregate: "2 correct, 0 incorrect",
correct: "1 milk_bottle, 1 can",
correct: "1 can, 1 milk_bottle",
incorrect: "(none)"
},
{
Expand Down
2 changes: 1 addition & 1 deletion libs/core-ui/src/lib/DatasetCohort.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ export class DatasetCohort {
dataDict[index][DatasetCohortColumns.PredictedY] = val;
}
});
this.dataset.objectDetectionLabels?.forEach((val, index) => {
this.dataset.object_detection_labels?.forEach((val, index) => {
dataDict[index][DatasetCohortColumns.ObjectDetectionIncorrect] =
val.incorrect;
dataDict[index][DatasetCohortColumns.ObjectDetectionCorrect] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ export function buildInitialModelAssessmentContext(
featureMetaData: props.dataset.feature_metadata,
localExplanations,
metadata: modelMetadata,
objectDetectionLabels: props.dataset.objectDetectionLabels,
objectDetectionLabels: props.dataset.object_detection_labels,
predictedProbabilities: props.dataset.probability_y,
predictedY: props.dataset.predicted_y,
targetColumn: props.dataset.target_column,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,19 +738,26 @@ def _generate_od_error_labels(self, true_y, pred_y, class_names):
else:
image_labels[_INCORRECT][object_label] += 1

image_labels[_INCORRECT][object_label] += \
np.count_nonzero(
error_matrix[label_idx] ==
ErrorLabelType.DUPLICATE_DETECTION)
duplicate_detections = np.count_nonzero(
error_matrix[label_idx] ==
ErrorLabelType.DUPLICATE_DETECTION)
if duplicate_detections > 0:
image_labels[_INCORRECT][object_label] += \
duplicate_detections

correct_labels = sorted(image_labels[_CORRECT].items(),
key=lambda x: class_names.index(x[0]))
incorrect_labels = sorted(image_labels[_INCORRECT].items(),
key=lambda x: class_names.index(x[0]))

rendered_labels[_CORRECT] = ', '.join(
f'{value} {key}' for key, value in
image_labels[_CORRECT].items() if value > 0)
correct_labels)
if len(rendered_labels[_CORRECT]) == 0:
rendered_labels[_CORRECT] = _NOLABEL
rendered_labels[_INCORRECT] = ', '.join(
f'{value} {key}' for key, value in
image_labels[_INCORRECT].items() if value > 0)
incorrect_labels)
if len(rendered_labels[_INCORRECT]) == 0:
rendered_labels[_INCORRECT] = _NOLABEL
rendered_labels[_AGGREGATE_LABEL] = \
Expand Down

0 comments on commit 896237b

Please sign in to comment.