Refactor torch-onnx example
goodsong81 committed Sep 12, 2024
1 parent cc5646c commit 202b049
Showing 2 changed files with 112 additions and 48 deletions.
2 changes: 1 addition & 1 deletion docs/source/user-guide.md
@@ -354,7 +354,7 @@ saliency_maps = saliency_maps.numpy(force=True).squeeze(0) # Cxhxw
saliency_map = saliency_maps[label] # hxw saliency_map for the label
saliency_map = colormap(saliency_map[None, :]) # 1xhxw
saliency_map = cv2.resize(saliency_map.squeeze(0), dsize=input_size) # HxW
saliency_image = overlay(saliency_map, image)
result_image = overlay(saliency_map, image)
```

## XAI method overview
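Regarding the `overlay(saliency_map, image)` call in the user-guide snippet above: it blends the colormapped saliency map with the original image for visualization. A minimal illustrative sketch of that kind of alpha blending is shown below; it is an assumption about typical behavior, not the library's actual `overlay` implementation.

```python
import numpy as np

def overlay_sketch(saliency_map: np.ndarray, image: np.ndarray, opacity: float = 0.5) -> np.ndarray:
    """Alpha-blend an HxWx3 colormapped saliency map onto the HxWx3 RGB image (hypothetical helper)."""
    blended = image.astype(np.float32) * (1 - opacity) + saliency_map.astype(np.float32) * opacity
    return np.clip(blended, 0, 255).astype(np.uint8)
```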
158 changes: 111 additions & 47 deletions examples/run_torch_onnx.py
@@ -25,9 +25,8 @@ def get_argument_parser():
return parser


def main(argv: list[str]):
parser = get_argument_parser()
args = parser.parse_args(argv)
def run_insert_xai_torch(args):
"""Insert XAI head into PyTorch model and run inference on PyTorch Runtime to get saliency map."""

# Load Torch model from timm
try:
@@ -37,7 +36,7 @@ def main(argv: list[str]):
except Exception as e:
logger.error(e)
logger.info(f"Please choose from {timm.list_models()}")
sys.exit(1)
return
input_size = model.default_cfg["input_size"][1:] # (H, W)
input_mean = np.array(model.default_cfg["mean"])
input_std = np.array(model.default_cfg["std"])
@@ -76,52 +75,48 @@ def main(argv: list[str]):
saliency_map = saliency_maps[label] # hxw saliency_map for the label
saliency_map = colormap(saliency_map[None, :]) # 1xhxw
saliency_map = cv2.resize(saliency_map.squeeze(0), dsize=input_size) # HxW
saliency_image = overlay(saliency_map, image)
saliency_image = cv2.cvtColor(saliency_image, code=cv2.COLOR_RGB2BGR)
saliency_image_path = Path(args.output_dir) / "xai-torch.png"
saliency_image_path.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(saliency_image_path, saliency_image)
logger.info(f"Torch XAI model saliency map: {saliency_image_path}")
result_image = overlay(saliency_map, image)
result_image = cv2.cvtColor(result_image, code=cv2.COLOR_RGB2BGR)
result_image_path = Path(args.output_dir) / "xai-torch.png"
result_image_path.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(result_image_path, result_image)
logger.info(f"Torch XAI model saliency map: {result_image_path}")

# OpenVINO model conversion
ov_model = ov.convert_model(
model_xai,
example_input=torch.from_numpy(image_norm),
input=(ov.PartialShape([-1, *image_norm.shape[1:]],))
)
model_path = Path(args.output_dir) / "model.xml"
model_path.parent.mkdir(parents=True, exist_ok=True)
ov.save_model(ov_model, model_path)
logger.info(f"OpenVINO XAI model: {model_path}")

# OpenVINO XAI model inference
ov_model = ov.Core().compile_model(ov_model, device_name="CPU")
outputs = ov_model(image_norm)
logits = outputs["prediction"] # BxC
saliency_maps = outputs["saliency_map"] # BxCxhxw
probs = softmax(logits)
label = probs.argmax(axis=-1)[0]
logger.info(f"OpenVINO XAI model prediction: classes ({probs.shape[-1]}) -> label ({label}) -> prob ({probs[0, label]})")

# OpenVINO XAI model saliency map
saliency_maps = saliency_maps.squeeze(0) # Cxhxw
saliency_map = saliency_maps[label] # hxw saliency_map for the label
saliency_map = colormap(saliency_map[None, :]) # 1xhxw
saliency_map = cv2.resize(saliency_map.squeeze(0), dsize=input_size) # HxW
saliency_image = overlay(saliency_map, image)
saliency_image = cv2.cvtColor(saliency_image, code=cv2.COLOR_RGB2BGR)
saliency_image_path = Path(args.output_dir) / "xai-openvino.png"
saliency_image_path.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(saliency_image_path, saliency_image)
logger.info(f"OpenVINO XAI model saliency map: {saliency_image_path}")
def run_insert_xai_torch_to_onnx(args):
"""Insert XAI head into PyTorch model, then convert to ONNX format and run inference on ONNX Runtime to get saliency map."""

# ONNX import
try:
importlib.import_module("onnx")
onnxruntime = importlib.import_module("onnxruntime")
except Exception:
logger.info("Please install onnx and onnxruntime package to run ONNX XAI example.")
sys.exit(0)
return

# Load Torch model from timm
try:
model = timm.create_model(args.model_name, in_chans=3, pretrained=True)
logger.info(f"Model config: {model.default_cfg}")
logger.info(f"Model layers: {model}")
except Exception as e:
logger.error(e)
logger.info(f"Please choose from {timm.list_models()}")
return
input_size = model.default_cfg["input_size"][1:] # (H, W)
input_mean = np.array(model.default_cfg["mean"])
input_std = np.array(model.default_cfg["std"])

# Load image
image = cv2.imread("tests/assets/cheetah_person.jpg")
image = cv2.resize(image, dsize=input_size)
image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB)
image_norm = ((image/255.0 - input_mean)/input_std).astype(np.float32)
image_norm = image_norm.transpose((2, 0, 1)) # HxWxC -> CxHxW
image_norm = image_norm[None, :] # CxHxW -> 1xCxHxW

# Insert XAI head
model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION)

# ONNX model conversion
model_path = Path(args.output_dir) / "model.onnx"
@@ -151,12 +146,81 @@ def main(argv: list[str]):
saliency_map = saliency_maps[label] # hxw saliency_map for the label
saliency_map = colormap(saliency_map[None, :]) # 1xhxw
saliency_map = cv2.resize(saliency_map.squeeze(0), dsize=input_size) # HxW
saliency_image = overlay(saliency_map, image)
saliency_image = cv2.cvtColor(saliency_image, code=cv2.COLOR_RGB2BGR)
saliency_image_path = Path(args.output_dir) / "xai-onnx.png"
saliency_image_path.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(saliency_image_path, saliency_image)
logger.info(f"ONNX XAI model saliency map: {saliency_image_path}")
result_image = overlay(saliency_map, image)
result_image = cv2.cvtColor(result_image, code=cv2.COLOR_RGB2BGR)
result_image_path = Path(args.output_dir) / "xai-onnx.png"
result_image_path.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(result_image_path, result_image)
logger.info(f"ONNX XAI model saliency map: {result_image_path}")


def run_insert_xai_torch_to_openvino(args):
"""Insert XAI head into PyTorch model, then convert to OpenVINO format and run inference on OpenVINO Runtime to get saliency map."""

# Load Torch model from timm
try:
model = timm.create_model(args.model_name, in_chans=3, pretrained=True)
logger.info(f"Model config: {model.default_cfg}")
logger.info(f"Model layers: {model}")
except Exception as e:
logger.error(e)
logger.info(f"Please choose from {timm.list_models()}")
return
input_size = model.default_cfg["input_size"][1:] # (H, W)
input_mean = np.array(model.default_cfg["mean"])
input_std = np.array(model.default_cfg["std"])

# Load image
image = cv2.imread("tests/assets/cheetah_person.jpg")
image = cv2.resize(image, dsize=input_size)
image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB)
image_norm = ((image/255.0 - input_mean)/input_std).astype(np.float32)
image_norm = image_norm.transpose((2, 0, 1)) # HxWxC -> CxHxW
image_norm = image_norm[None, :] # CxHxW -> 1xCxHxW

# Insert XAI head
model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION)

# OpenVINO model conversion
ov_model = ov.convert_model(
model_xai,
example_input=torch.from_numpy(image_norm),
input=(ov.PartialShape([-1, *image_norm.shape[1:]],))
)
model_path = Path(args.output_dir) / "model.xml"
model_path.parent.mkdir(parents=True, exist_ok=True)
ov.save_model(ov_model, model_path)
logger.info(f"OpenVINO XAI model: {model_path}")

# OpenVINO XAI model inference
ov_model = ov.Core().compile_model(ov_model, device_name="CPU")
outputs = ov_model(image_norm)
logits = outputs["prediction"] # BxC
saliency_maps = outputs["saliency_map"] # BxCxhxw
probs = softmax(logits)
label = probs.argmax(axis=-1)[0]
logger.info(f"OpenVINO XAI model prediction: classes ({probs.shape[-1]}) -> label ({label}) -> prob ({probs[0, label]})")

# OpenVINO XAI model saliency map
saliency_maps = saliency_maps.squeeze(0) # Cxhxw
saliency_map = saliency_maps[label] # hxw saliency_map for the label
saliency_map = colormap(saliency_map[None, :]) # 1xhxw
saliency_map = cv2.resize(saliency_map.squeeze(0), dsize=input_size) # HxW
result_image = overlay(saliency_map, image)
result_image = cv2.cvtColor(result_image, code=cv2.COLOR_RGB2BGR)
result_image_path = Path(args.output_dir) / "xai-openvino.png"
result_image_path.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(result_image_path, result_image)
logger.info(f"OpenVINO XAI model saliency map: {result_image_path}")


def main(argv: list[str]):
parser = get_argument_parser()
args = parser.parse_args(argv)

run_insert_xai_torch(args)
run_insert_xai_torch_to_onnx(args)
run_insert_xai_torch_to_openvino(args)


if __name__ == "__main__":

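For orientation, the refactored example can be driven programmatically roughly as sketched below. The function names, the `args` attributes (`model_name`, `output_dir`), and the output filenames are taken from the diff above; the import path and the timm model name are assumptions made for illustration.

```python
# Sketch of calling the three refactored stages directly, bypassing the CLI parser.
from argparse import Namespace

from examples.run_torch_onnx import (  # assumed import path for the example module
    run_insert_xai_torch,
    run_insert_xai_torch_to_onnx,
    run_insert_xai_torch_to_openvino,
)

# Attribute names match those read inside the functions; the model name is an example.
args = Namespace(model_name="resnet18.a1_in1k", output_dir="outputs")

run_insert_xai_torch(args)              # PyTorch runtime  -> outputs/xai-torch.png
run_insert_xai_torch_to_onnx(args)      # ONNX Runtime     -> outputs/model.onnx, outputs/xai-onnx.png
run_insert_xai_torch_to_openvino(args)  # OpenVINO runtime -> outputs/model.xml, outputs/xai-openvino.png
```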