I use the following code to train an adversarial training ensemble (AdversarialTrainingClassifier) on CIFAR-100:
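import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchensemble import AdversarialTrainingClassifier

transform_train_cifar = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # scales pixel values to [0, 1]
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# CIFAR100 to match num_classes = 100; transform is passed by keyword because
# the second positional argument of the dataset class is `train`
traindataset = datasets.CIFAR100(root, train=True, transform=transform_train_cifar)
# batch_size=128 is an illustrative value
train_loader = DataLoader(traindataset, batch_size=128, shuffle=True)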
num_classes = 100
base_estimator = torchvision.models.resnet18()  # randomly initialized, no pretrained weights
base_estimator.avgpool = nn.AdaptiveAvgPool2d(1)
num_ftrs = base_estimator.fc.in_features
base_estimator.fc = nn.Linear(num_ftrs, num_classes)
ensemble = AdversarialTrainingClassifier(
    estimator=base_estimator,  # estimator is your pytorch model
    n_estimators=args.num,  # number of base estimators
    cuda=True,
)
criterion = nn.CrossEntropyLoss()
ensemble.set_criterion(criterion)
# Set the optimizer
print('Setting optimizer...')
ensemble.set_optimizer(
    "Adam",  # type of parameter optimizer
    lr=args.lr,  # learning rate of parameter optimizer
    weight_decay=args.weight_decay,  # weight decay of parameter optimizer
)
# Set the learning rate scheduler
print('Setting scheduler...')
ensemble.set_scheduler(
    "CosineAnnealingLR",  # type of learning rate scheduler
    T_max=args.epochs,  # additional arguments on the scheduler
)
# Train the ensemble
print('Start training...')
ensemble.fit(
    train_loader,
    epochs=args.epochs,
)
However, running this code produces the following error:
Traceback (most recent call last):
  File "train.py", line 251, in <module>
    ensemble.fit(
  File "/data/anaconda3/envs/mae/lib/python3.8/site-packages/torchensemble/adversarial_training.py", line 324, in fit
    rets = parallel(
  File "/data/anaconda3/envs/mae/lib/python3.8/site-packages/joblib/parallel.py", line 1085, in __call__
    if self.dispatch_one_batch(iterator):
  File "/data/anaconda3/envs/mae/lib/python3.8/site-packages/joblib/parallel.py", line 901, in dispatch_one_batch
    self._dispatch(tasks)
  File "/data/anaconda3/envs/mae/lib/python3.8/site-packages/joblib/parallel.py", line 819, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "/data/anaconda3/envs/mae/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 208, in apply_async
    result = ImmediateResult(func)
  File "/data/anaconda3/envs/mae/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 597, in __init__
    self.results = batch()
  File "/data/anaconda3/envs/mae/lib/python3.8/site-packages/joblib/parallel.py", line 288, in __call__
    return [func(*args, **kwargs)
  File "/data/anaconda3/envs/mae/lib/python3.8/site-packages/joblib/parallel.py", line 288, in <listcomp>
    return [func(*args, **kwargs)
  File "/data/anaconda3/envs/mae/lib/python3.8/site-packages/torchensemble/adversarial_training.py", line 122, in _parallel_fit_per_epoch
    adv_data = _get_fgsm_samples(data, epsilon, data_grad)
  File "/data/anaconda3/envs/mae/lib/python3.8/site-packages/torchensemble/adversarial_training.py", line 176, in _get_fgsm_samples
    raise ValueError(msg.format(min_value, max_value))
ValueError: The input range of samples passed to adversarial training should be in the range [0, 1], but got [-2.429, 2.754] instead.
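The bounds in the error message are exactly what transforms.Normalize produces from ToTensor output. ToTensor yields values in [0, 1], and Normalize maps each channel by (x - mean) / std, so the smallest possible value is the most negative (0 - mean) / std and the largest is the most positive (1 - mean) / std across the three channels. The short sketch below reproduces that arithmetic with the mean and std from the transform above; normalized_range is a hypothetical helper written only for this illustration.

```python
# Mean and std passed to transforms.Normalize in the training transform above
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def normalized_range(mean, std):
    """Return the (min, max) values Normalize can produce from inputs in [0, 1]."""
    # For each channel, x in [0, 1] maps to (x - m) / s, which increases with x,
    # so the extremes occur at x = 0 and x = 1
    lows = [(0.0 - m) / s for m, s in zip(mean, std)]
    highs = [(1.0 - m) / s for m, s in zip(mean, std)]
    return min(lows), max(highs)


low, high = normalized_range(MEAN, STD)
print(f"[{low:.3f}, {high:.3f}]")  # → [-2.429, 2.754]
```

The lower bound comes from the first channel ((0 - 0.4914) / 0.2023) and the upper bound from the third ((1 - 0.4465) / 0.2010), so the batch handed to _get_fgsm_samples was normalized data rather than raw [0, 1] pixels.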
Should I remove the normalization step from my data transform? Thanks.