Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update hyperparameters of bb methods #65

Merged
merged 4 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/run_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ def explain_white_box(args):

# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "detection"
explanation.save(output, Path(args.image_path).stem)
output = Path(args.output) / "detection_white_box"
ori_image_name = Path(args.image_path).stem
explanation.save(output, f"{ori_image_name}_")


def explain_black_box(args):
Expand Down Expand Up @@ -131,7 +132,8 @@ def explain_black_box(args):
# Save saliency maps for visual inspection
if args.output is not None:
output = Path(args.output) / "detection_black_box"
explanation.save(output, f"{Path(args.image_path).stem}_")
ori_image_name = Path(args.image_path).stem
explanation.save(output, f"{ori_image_name}_")


def main(argv):
Expand Down
6 changes: 3 additions & 3 deletions openvino_xai/methods/black_box/aise/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,14 @@ def _preset_parameters(
kernel_widths: List[float] | np.ndarray | None,
) -> Tuple[int, np.ndarray]:
if preset == Preset.SPEED:
iterations = 25
iterations = 20
widths = np.linspace(0.1, 0.25, 3)
elif preset == Preset.BALANCE:
iterations = 50
widths = np.linspace(0.1, 0.25, 3)
elif preset == Preset.QUALITY:
iterations = 85
widths = np.linspace(0.075, 0.25, 4)
iterations = 50
widths = np.linspace(0.075, 0.25, 5)
else:
raise ValueError(f"Preset {preset} is not supported.")

Expand Down
6 changes: 3 additions & 3 deletions openvino_xai/methods/black_box/aise/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,13 @@ def _preset_parameters(
divisors: List[float] | np.ndarray | None,
) -> Tuple[int, np.ndarray]:
if preset == Preset.SPEED:
iterations = 50
iterations = 20
divs = np.linspace(7, 1, 3)
elif preset == Preset.BALANCE:
iterations = 100
iterations = 50
divs = np.linspace(7, 1, 3)
elif preset == Preset.QUALITY:
iterations = 150
iterations = 50
divs = np.linspace(8, 1, 5)
else:
raise ValueError(f"Preset {preset} is not supported.")
Expand Down
27 changes: 17 additions & 10 deletions openvino_xai/methods/black_box/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def generate_saliency_map(
target_indices: List[int] | None = None,
preset: Preset = Preset.BALANCE,
num_masks: int | None = None,
num_cells: int = 8,
num_cells: int | None = None,
prob: float = 0.5,
seed: int = 0,
scale_output: bool = True,
Expand Down Expand Up @@ -84,7 +84,7 @@ def generate_saliency_map(
"""
data_preprocessed = self.preprocess_fn(data)

num_masks = self._preset_parameters(preset, num_masks)
num_masks, num_cells = self._preset_parameters(preset, num_masks, num_cells)

saliency_maps = self._run_synchronous_explanation(
data_preprocessed,
Expand All @@ -109,20 +109,27 @@ def generate_saliency_map(
def _preset_parameters(
preset: Preset,
num_masks: int | None = None,
) -> int:
# TODO (negvet): preset num_cells
if num_masks is not None:
return num_masks

num_cells: int | None = None,
) -> Tuple[int, int]:
if preset == Preset.SPEED:
return 2000
num_masks_ = 1000
num_cells_ = 4
elif preset == Preset.BALANCE:
return 5000
num_masks_ = 5000
num_cells_ = 8
elif preset == Preset.QUALITY:
return 8000
num_masks_ = 10000
num_cells_ = 12
else:
raise ValueError(f"Preset {preset} is not supported.")

if num_masks is None:
num_masks = num_masks_
if num_cells is None:
num_cells = num_cells_

return num_masks, num_cells

def _run_synchronous_explanation(
self,
data_preprocessed: np.ndarray,
Expand Down