Skip to content

Commit

Permalink
Data Analysis Label Fix for OD (#2365)
Browse files Browse the repository at this point in the history
* pandas operation fix

* cohort fix

* windows test fix

* backend & dataset fixes

* reverted test file

* data analysis label fix ckpt

* reverted changes from another PR

* auto lint fixes

* container support

* auto lint fixes

* reduced complexity

* refactor + auto lint fixes

* var fix

* nolabel var fix

* auto lint fixes
  • Loading branch information
Advitya17 authored Sep 29, 2023
1 parent 40af9df commit 7206fe1
Show file tree
Hide file tree
Showing 9 changed files with 353 additions and 189 deletions.
24 changes: 22 additions & 2 deletions libs/core-ui/src/lib/DatasetCohort.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@ import {
Operations
} from "./Interfaces/IFilter";
import { getPropertyValues } from "./util/datasetUtils/getPropertyValues";
import { IsBinary, IsMulticlass } from "./util/ExplanationUtils";
import { MulticlassClassificationEnum } from "./util/JointDatasetUtils";
import {
IsBinary,
IsMulticlass,
IsObjectDetection
} from "./util/ExplanationUtils";
import {
MulticlassClassificationEnum,
NoLabel
} from "./util/JointDatasetUtils";

export class DatasetCohort {
public selectedIndexes: number[] = [];
Expand Down Expand Up @@ -134,6 +141,12 @@ export class DatasetCohort {
dataDict[index][DatasetCohortColumns.PredictedY] = val;
}
});
this.dataset.objectDetectionLabels?.forEach((val, index) => {
dataDict[index][DatasetCohortColumns.ObjectDetectionIncorrect] =
val.incorrect;
dataDict[index][DatasetCohortColumns.ObjectDetectionCorrect] =
val.correct;
});
// set up errors
if (modelType === ModelTypes.Regression) {
for (const [index, row] of dataDict.entries()) {
Expand All @@ -160,6 +173,13 @@ export class DatasetCohort {
? MulticlassClassificationEnum.Misclassified
: MulticlassClassificationEnum.Correct;
}
} else if (modelType && IsObjectDetection(modelType)) {
for (const [index, row] of dataDict.entries()) {
dataDict[index][DatasetCohortColumns.ClassificationError] =
row[DatasetCohortColumns.ObjectDetectionIncorrect] !== NoLabel
? MulticlassClassificationEnum.Misclassified
: MulticlassClassificationEnum.Correct;
}
}
return dataDict;
}
Expand Down
2 changes: 2 additions & 0 deletions libs/core-ui/src/lib/DatasetCohortColumns.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ export enum DatasetCohortColumns {
Dataset = "Data",
PredictedY = "Predicted Y",
TrueY = "True Y",
ObjectDetectionIncorrect = "Incorrect",
ObjectDetectionCorrect = "Correct",
ClassificationError = "Classification outcome",
RegressionError = "Regression error",
ProbabilityY = "Probability Y"
Expand Down
2 changes: 1 addition & 1 deletion libs/core-ui/src/lib/Interfaces/IDataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export interface ITabularDatasetMetadata {
feature_ranges: Array<{ [key: string]: any }>;
}

interface IObjectDetectionLabelType {
export interface IObjectDetectionLabelType {
correct: string;
incorrect: string;
aggregate: string;
Expand Down
54 changes: 50 additions & 4 deletions libs/core-ui/src/lib/util/DatasetUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { ColumnActionsMode, IColumn } from "@fluentui/react";

import { JointDataset } from "../util/JointDataset";

import { MulticlassClassificationEnum } from "./JointDatasetUtils";

export interface ITableState {
rows: any[];
columns: IColumn[];
Expand All @@ -14,6 +16,12 @@ export function areRowPredTrueLabelsEqual(
row: { [key: string]: number },
jointDataset: JointDataset
): boolean {
if (jointDataset.hasODIncorrect && jointDataset.hasODCorrect) {
return (
row[JointDataset.ClassificationError] ===
MulticlassClassificationEnum.Correct
);
}
if (jointDataset.numLabels === 1) {
return row[JointDataset.PredictedYLabel] === row[JointDataset.TrueYLabel];
}
Expand Down Expand Up @@ -57,7 +65,15 @@ export function constructRows(
if (colors) {
tableRow.push(colors[i]);
}
if (jointDataset.hasTrueY) {
if (jointDataset.hasODCorrect) {
pushRowData(
tableRow,
JointDataset.ObjectDetectionCorrect,
jointDataset,
row,
index
);
} else if (jointDataset.hasTrueY) {
if (jointDataset.numLabels > 1) {
pushMultilabelRowData(
tableRow,
Expand All @@ -75,7 +91,15 @@ export function constructRows(
);
}
}
if (jointDataset.hasPredictedY) {
if (jointDataset.hasODIncorrect) {
pushRowData(
tableRow,
JointDataset.ObjectDetectionIncorrect,
jointDataset,
row,
index
);
} else if (jointDataset.hasPredictedY) {
if (jointDataset.numLabels > 1) {
pushMultilabelRowData(
tableRow,
Expand Down Expand Up @@ -147,7 +171,18 @@ export function constructCols(
});
index++;
}
if (!isCustomPointsView && jointDataset.hasTrueY) {
if (jointDataset.hasODCorrect) {
columns.push({
columnActionsMode: ColumnActionsMode.disabled,
fieldName: `${index}`,
isResizable: true,
key: `column${index}`,
maxWidth: 100,
minWidth: 50,
name: "Correct"
});
index++;
} else if (!isCustomPointsView && jointDataset.hasTrueY) {
columns.push({
columnActionsMode: ColumnActionsMode.disabled,
fieldName: `${index}`,
Expand All @@ -159,7 +194,18 @@ export function constructCols(
});
index++;
}
if (!isCustomPointsView && jointDataset.hasPredictedY) {
if (jointDataset.hasODIncorrect) {
columns.push({
columnActionsMode: ColumnActionsMode.disabled,
fieldName: `${index}`,
isResizable: true,
key: `column${index}`,
maxWidth: 100,
minWidth: 50,
name: "Incorrect"
});
index++;
} else if (!isCustomPointsView && jointDataset.hasPredictedY) {
columns.push({
columnActionsMode: ColumnActionsMode.disabled,
fieldName: `${index}`,
Expand Down
4 changes: 4 additions & 0 deletions libs/core-ui/src/lib/util/ExplanationUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ export function IsMultilabel(modelType: ModelTypes): boolean {
);
}

export function IsObjectDetection(modelType: ModelTypes): boolean {
return modelType === ModelTypes.ObjectDetection;
}

export function IsClassifier(modelType: ModelTypes): boolean {
return (
modelType === ModelTypes.Binary ||
Expand Down
Loading

0 comments on commit 7206fe1

Please sign in to comment.