This repository contains the PyTorch implementation of Constrained Parameter Regularization (CPR) with the Adam optimizer. CPR is an alternative to traditional weight decay. Instead of applying a single uniform penalty to all parameters, CPR enforces an upper bound on a statistical measure, such as the L2 norm, of each individual parameter matrix. CPR introduces only a minor runtime overhead and only requires choosing the upper bound (or sets it automatically via inflection point detection).
AdamCPR outperforms AdamW on various tasks, such as image classification (CIFAR100 and ImageNet) and language model fine-tuning and pretraining (GPT2/OpenWebText), as shown in the figure below.
The figure shows the validation perplexity of a GPT2s model trained on OpenWebText with AdamW for 200k steps (blue) and 300k steps (purple) vs. AdamCPR with inflection point detection (green). The CPR model converges more linearly and reaches a lower validation perplexity, equivalent to training 50% longer with AdamW. Please find more experiments in our paper.
With CPR, learning becomes a constrained optimization problem, which we tackle using an adaptation of the augmented Lagrangian method. We implement this by adding a Lagrange multiplier for each regularized parameter matrix. The upper bound can be initialized in four ways: 'uniform' with a fixed value, 'dependent' on the initial parameter norm, 'warm_start' based on the norm after X training steps, or via an 'inflection_point' detection-based method which doesn't require any additional hyperparameter for the regularization.
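As a rough sketch of the underlying scheme (notation ours, not verbatim from the paper): with a regularization measure $R$ (e.g. the L2 norm), a per-matrix upper bound $\kappa_j$, and a multiplier update rate $\mu$, the constrained problem and the augmented-Lagrangian-style update take the form

$$\min_{\theta} \; L(\theta) \quad \text{s.t.} \quad R(\theta_j) \le \kappa_j \;\; \text{for each parameter matrix } \theta_j,$$

$$\lambda_j \leftarrow \max\bigl(0,\; \lambda_j + \mu\,(R(\theta_j) - \kappa_j)\bigr), \qquad \theta_j \leftarrow \theta_j - \eta\,\lambda_j \nabla R(\theta_j) \;\; \text{(on top of the Adam update)}.$$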
We implement this Lagrange optimization directly in the Adam optimizer, which we call AdamCPR. You can install the package via pip:
pip install pytorch-cpr
We implemented CPR with the Adam optimizer in PyTorch (v2.3.1+). To use CPR, simply replace the optimizer in your training script with the AdamCPR optimizer.
from pytorch_cpr import AdamCPR
# for AdamCPR with warm start initialization
optimizer = AdamCPR(model, lr=0.001, kappa_init_param=1000, kappa_init_method='warm_start')
# for AdamCPR with inflection point initialization (no other regularization hyperparameter needed)
optimizer = AdamCPR(model, lr=0.001, kappa_init_method='inflection_point')
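For context, here is a minimal end-to-end sketch of dropping AdamCPR into a standard PyTorch training loop (the model, data, and hyperparameters below are illustrative, not taken from the paper):

```python
import torch
import torch.nn as nn
from pytorch_cpr import AdamCPR

# Toy model and random data, for illustration only
model = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = AdamCPR(model, lr=1e-3, kappa_init_method='inflection_point')
criterion = nn.CrossEntropyLoss()

for step in range(1000):
    inputs = torch.randn(64, 32)            # dummy batch
    targets = torch.randint(0, 10, (64,))   # dummy labels
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()  # the CPR constraint update is applied inside step()
```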
| Parameter | Type | Default | Description |
|---|---|---|---|
| `params` | iterable | required | Iterable of parameters to optimize or dicts defining parameter groups |
| `lr` | float | 1e-3 | Learning rate. Note: a Tensor lr is only supported with capturable=True |
| `betas` | tuple(float, float) | (0.9, 0.999) | Coefficients for computing running averages of the gradient and its square |
| `eps` | float | 1e-8 | Term added to the denominator for numerical stability |
| `amsgrad` | bool | False | Whether to use the AMSGrad variant from "On the Convergence of Adam and Beyond" |
| Parameter | Type | Default | Description |
|---|---|---|---|
| `kappa_init_method` | str | 'inflection_point' | Method to initialize the regularization bound. Options: • 'uniform': fixed-value initialization • 'warm_start': delayed initialization • 'dependent': parameter-dependent initialization • 'inflection_point': automatic initialization based on inflection point detection |
| `kappa_init_param` | float | 1000.0 | Initial value for the regularization bound; its meaning depends on the initialization method: • 'uniform': the value of the upper bound • 'warm_start': the number of steps before setting the upper bound to the current regularization value • 'dependent': the factor applied to the regularization value after initialization • 'inflection_point': no parameter required |
| `reg_function` | str | 'l2' | Regularization function type. Options: • 'l2': L2 norm regularization • 'l1': L1 norm regularization • 'std': standard deviation regularization • 'huber': Huber norm regularization |
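As a further illustration of how the CPR-specific options above combine (the values are arbitrary examples, not tuned recommendations; `model` is any torch.nn.Module as in the usage example above):

```python
from pytorch_cpr import AdamCPR

# Fixed upper bound of 0.01 on the L2 norm of each regularized parameter matrix
optimizer = AdamCPR(model, lr=1e-3, kappa_init_method='uniform',
                    kappa_init_param=0.01, reg_function='l2')

# Bound set relative to the regularization value at initialization, using the
# standard-deviation measure instead of the L2 norm
optimizer = AdamCPR(model, lr=1e-3, kappa_init_method='dependent',
                    kappa_init_param=2.0, reg_function='std')
```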
We provide scripts to replicate the experiments from our paper. Please use a system with at least one GPU. Install the package and the requirements for the examples:
python3 -m venv venv
source venv/bin/activate
pip install -r examples/requirements.txt
pip install pytorch-cpr
The grokking experiment should run within a few minutes. The results will be saved in the `grokking` folder.
To replicate the results in the paper, run variations with the following arguments:
python examples/train_grokking_task.py --optimizer adamw --weight_decay 0.1
python examples/train_grokking_task.py --optimizer adamw --weight_decay 0.0 --rescale 0.8
python examples/train_grokking_task.py --optimizer adamcpr --kappa_init_method dependent --kappa_init_param 0.8
The CIFAR-100 experiment should run within 20-30 minutes. The results will be saved in the `cifar100` folder.
For AdamCPR with the L2 norm as regularization function and kappa initialized dependent on the parameter initialization or via warm start:
python examples/train_cifar100_task.py --optimizer adamcpr --lr 0.001 --kappa_init_method dependent --kappa_init_param 1.0
python examples/train_cifar100_task.py --optimizer adamcpr --lr 0.001 --kappa_init_method warm_start --kappa_init_param 1000
For AdamAdaCPR with the L2 norm as regularization function and kappa initialization via inflection point detection:
python examples/train_cifar100_task.py --optimizer adamcpr --lr 0.001 --kappa_init_method inflection_point
For the baseline methods:
python examples/train_cifar100_task.py --optimizer adamw --lr 0.001 --weight_decay 0.001
python examples/train_cifar100_task.py --optimizer adamw --lr 0.001 --weight_decay 0 --rescale_alpha 0.8
python examples/train_cifar100_task.py --optimizer adam_awd --lr 0.001 --weight_decay 0.1
python examples/train_cifar100_task.py --optimizer adam_adadecay --lr 0.001 --weight_decay 0.1
Please cite our paper if you use CPR in your work:
@article{franke2024cpr,
title={Improving Deep Learning Optimization through Constrained Parameter Regularization},
author={Jörg K. H. Franke and Michael Hefenbrock and Gregor Koehler and Frank Hutter},
journal={Advances in Neural Information Processing Systems},
volume={37},
year={2024},
}