Skip to content

Commit

Permalink
Documentation for device change
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobgil committed Dec 19, 2023
1 parent 00711a2 commit 51ae192
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 42 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ input_tensor = # Create an input tensor image for your model..
# Note: input_tensor can be a batch tensor with several images!

# Construct the CAM object once, and then re-use it on many images:
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=args.use_cuda)
cam = GradCAM(model=model, target_layers=target_layers)

# You can also use it within a with statement, to make sure it is freed,
# In case you need to re-create it inside an outer loop:
# with GradCAM(model=model, target_layers=target_layers, use_cuda=args.use_cuda) as cam:
# with GradCAM(model=model, target_layers=target_layers) as cam:
# ...

# We have to specify the target we want to generate
Expand Down Expand Up @@ -244,8 +244,8 @@ two smoothing methods are supported:
Usage: `python cam.py --image-path <path_to_image> --method <method> --output-dir <output_dir_path> `


To use with CUDA:
`python cam.py --image-path <path_to_image> --use-cuda --output-dir <output_dir_path> `
To use with a specific device, like cpu, cuda, cuda:0 or mps:
`python cam.py --image-path <path_to_image> --device cuda --output-dir <output_dir_path> `

----------

Expand Down
12 changes: 5 additions & 7 deletions cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
import torch
from torchvision import models
from torchvision.models import ResNet50_Weights
from pytorch_grad_cam import (
GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
Expand All @@ -19,7 +18,7 @@

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default=None,
parser.add_argument('--device', type=str, default='cpu',
help='Torch device to use')
parser.add_argument(
'--image-path',
Expand Down Expand Up @@ -77,7 +76,7 @@ def get_args():
"gradcamelementwise": GradCAMElementWise
}

model = models.resnet50(weights=ResNet50_Weights.DEFAULT).to(args.device).eval()
model = models.resnet50(pretrained=True).to(torch.device(args.device)).eval()

# Choose the target layer you want to compute the visualization for.
# Usually this will be the last convolutional layer in the model.
Expand All @@ -104,16 +103,15 @@ def get_args():
# the Class Activation Maps for.
# If targets is None, the highest scoring category (for every member in the batch) will be used.
# You can target specific categories by
# targets = [e.g ClassifierOutputTarget(281)]
# targets = [ClassifierOutputTarget(281)]
# targets = [ClassifierOutputTarget(281)]
targets = None

# Using the with statement ensures the context is freed, and you can
# recreate different CAM objects in a loop.
cam_algorithm = methods[args.method]
with cam_algorithm(model=model,
target_layers=target_layers,
device=args.device) as cam:

target_layers=target_layers) as cam:

# AblationCAM and ScoreCAM have batched implementations.
# You can override the internal batch size for faster computation.
Expand Down
3 changes: 2 additions & 1 deletion pytorch_grad_cam/base_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ def __init__(self,
tta_transforms: Optional[tta.Compose] = None) -> None:
self.model = model.eval()
self.target_layers = target_layers
self.device = next(self.model.parameters()).device

# Use the same device as the model.
self.device = next(self.model.parameters()).device
self.reshape_transform = reshape_transform
self.compute_input_gradient = compute_input_gradient
self.uses_gradients = uses_gradients
Expand Down
2 changes: 1 addition & 1 deletion pytorch_grad_cam/score_cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_cam_weights(self,
upsample = torch.nn.UpsamplingBilinear2d(
size=input_tensor.shape[-2:])
activation_tensor = torch.from_numpy(activations)
activation_tensor = activation_tensor.to(next(self.model.parameters()).device)
activation_tensor = activation_tensor.to(self.device)

upsampled = upsample(activation_tensor)

Expand Down
58 changes: 29 additions & 29 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import setuptools

with open('README.md', mode='r', encoding='utf-8') as fh:
long_description = fh.read()

with open("requirements.txt", "r") as f:
requirements = f.readlines()

setuptools.setup(
name='grad-cam',
version='1.4.8',
author='Jacob Gildenblat',
author_email='[email protected]',
description='Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more',
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/jacobgil/pytorch-grad-cam',
project_urls={
'Bug Tracker': 'https://github.com/jacobgil/pytorch-grad-cam/issues',
},
classifiers=[
'Programming Language :: Python :: 3',
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
],
packages=setuptools.find_packages(
exclude=["*tutorials*"]),
python_requires='>=3.6',
install_requires=requirements)
import setuptools

with open('README.md', mode='r', encoding='utf-8') as fh:
long_description = fh.read()

with open("requirements.txt", "r") as f:
requirements = f.readlines()

setuptools.setup(
name='grad-cam',
version='1.5.0',
author='Jacob Gildenblat',
author_email='[email protected]',
description='Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more',
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/jacobgil/pytorch-grad-cam',
project_urls={
'Bug Tracker': 'https://github.com/jacobgil/pytorch-grad-cam/issues',
},
classifiers=[
'Programming Language :: Python :: 3',
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
],
packages=setuptools.find_packages(
exclude=["*tutorials*"]),
python_requires='>=3.8',
install_requires=requirements)

0 comments on commit 51ae192

Please sign in to comment.