Commit
ntianhe ren committed Jan 31, 2023
2 parents dc4b886 + faff7fe commit 40043aa
Showing 10 changed files with 372 additions and 0 deletions.
184 changes: 184 additions & 0 deletions configs/common/common_schedule.py
@@ -0,0 +1,184 @@
from fvcore.common.param_scheduler import (
    MultiStepParamScheduler,
    StepParamScheduler,
    StepWithFixedGammaParamScheduler,
    ConstantParamScheduler,
    CosineParamScheduler,
    LinearParamScheduler,
    ExponentialParamScheduler,
)

from detectron2.config import LazyCall as L
from detectron2.solver import WarmupParamScheduler


def multistep_lr_scheduler(
    values=[1.0, 0.1],
    warmup_steps=0,
    num_updates=90000,
    milestones=[82500],
    warmup_method="linear",
    warmup_factor=0.001,
):
    # total steps defaults to num_updates; if num_updates is None, use milestones[-1]
    if num_updates is None:
        total_steps = milestones[-1]
    else:
        total_steps = num_updates

    # define multi-step scheduler
    scheduler = L(MultiStepParamScheduler)(
        values=values,
        milestones=milestones,
        num_updates=num_updates,
    )

    # wrap with warmup scheduler
    return L(WarmupParamScheduler)(
        scheduler=scheduler,
        warmup_length=warmup_steps / total_steps,
        warmup_method=warmup_method,
        warmup_factor=warmup_factor,
    )


def step_lr_scheduler(
    values,
    warmup_steps,
    num_updates,
    warmup_method="linear",
    warmup_factor=0.001,
):
    # define step scheduler
    scheduler = L(StepParamScheduler)(
        values=values,
        num_updates=num_updates,
    )

    # wrap with warmup scheduler
    return L(WarmupParamScheduler)(
        scheduler=scheduler,
        warmup_length=warmup_steps / num_updates,
        warmup_method=warmup_method,
        warmup_factor=warmup_factor,
    )


def step_lr_scheduler_with_fixed_gamma(
    base_value,
    num_decays,
    gamma,
    num_updates,
    warmup_steps,
    warmup_method="linear",
    warmup_factor=0.001,
):
    # define step scheduler with fixed gamma
    scheduler = L(StepWithFixedGammaParamScheduler)(
        base_value=base_value,
        num_decays=num_decays,
        gamma=gamma,
        num_updates=num_updates,
    )

    # wrap with warmup scheduler
    return L(WarmupParamScheduler)(
        scheduler=scheduler,
        warmup_length=warmup_steps / num_updates,
        warmup_method=warmup_method,
        warmup_factor=warmup_factor,
    )


def cosine_lr_scheduler(
    start_value,
    end_value,
    num_updates,
    warmup_steps,
    warmup_method="linear",
    warmup_factor=0.001,
):
    # define cosine scheduler
    scheduler = L(CosineParamScheduler)(
        start_value=start_value,
        end_value=end_value,
    )

    # wrap with warmup scheduler
    return L(WarmupParamScheduler)(
        scheduler=scheduler,
        warmup_length=warmup_steps / num_updates,
        warmup_method=warmup_method,
        warmup_factor=warmup_factor,
    )


def linear_lr_scheduler(
    start_value,
    end_value,
    num_updates,
    warmup_steps,
    warmup_method="linear",
    warmup_factor=0.001,
):
    # define linear scheduler
    scheduler = L(LinearParamScheduler)(
        start_value=start_value,
        end_value=end_value,
    )

    # wrap with warmup scheduler
    return L(WarmupParamScheduler)(
        scheduler=scheduler,
        warmup_length=warmup_steps / num_updates,
        warmup_method=warmup_method,
        warmup_factor=warmup_factor,
    )


def constant_lr_scheduler(
    value,
    num_updates,
    warmup_steps,
    warmup_method="linear",
    warmup_factor=0.001,
):
    # define constant scheduler
    scheduler = L(ConstantParamScheduler)(
        value=value,
    )

    # wrap with warmup scheduler
    return L(WarmupParamScheduler)(
        scheduler=scheduler,
        warmup_length=warmup_steps / num_updates,
        warmup_method=warmup_method,
        warmup_factor=warmup_factor,
    )


def exponential_lr_scheduler(
    start_value,
    decay,
    num_updates,
    warmup_steps,
    warmup_method="linear",
    warmup_factor=0.001,
):
    # define exponential scheduler
    scheduler = L(ExponentialParamScheduler)(
        start_value=start_value,
        decay=decay,
    )

    # wrap with warmup scheduler
    return L(WarmupParamScheduler)(
        scheduler=scheduler,
        warmup_length=warmup_steps / num_updates,
        warmup_method=warmup_method,
        warmup_factor=warmup_factor,
    )
187 changes: 187 additions & 0 deletions docs/source/tutorials/Customize_Training.md
@@ -0,0 +1,187 @@
# Customize Training Settings
This document provides a brief tutorial on how to customize training components in detrex.

## Customize LR Scheduler
We provide a series of commonly used scheduler configs in [common_schedule.py](), a simple wrapper around fvcore's [ParamScheduler](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.common.param_scheduler.ParamScheduler) for easier use in detrex. Here we provide examples with images to demonstrate the use of these default configurations. You can also refer to fvcore's [documentation](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#) for a more detailed API reference. Note that **all of the pre-defined scheduler functions return a scheduler config object**; use it in your own config file by assigning it to the ``lr_multiplier`` namespace. See [Config System](https://detrex.readthedocs.io/en/latest/tutorials/Config_System.html) for more details on configuration.
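
For example, a minimal sketch of wiring a scheduler into your own config (the file name `configs/my_config.py` is hypothetical; only the ``lr_multiplier`` assignment matters):

```python
# configs/my_config.py (hypothetical file name)
from detrex.config.configs.common.common_schedule import multistep_lr_scheduler

# assign the returned config object to the `lr_multiplier` namespace;
# the training loop will instantiate it lazily
lr_multiplier = multistep_lr_scheduler(
    values=[1.0, 0.1],
    warmup_steps=1000,
    num_updates=90000,
    milestones=[82500],
)
```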

### MultiStep LR Scheduler
A modified version of the multi-step scheduler, based on fvcore's [MultiStepParamScheduler](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.common.param_scheduler.MultiStepParamScheduler), which has the same functionality as PyTorch's [MultiStepLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html?highlight=multistep#torch.optim.lr_scheduler.MultiStepLR).

**Example:**
```python
from detectron2.config import instantiate
from detrex.config.configs.common.common_schedule import multistep_lr_scheduler

# define `lr_multiplier` config
lr_multiplier = multistep_lr_scheduler(
    values=[1.0, 0.1, 0.01],
    warmup_steps=100,
    num_updates=1000,
    milestones=[600, 900],
    warmup_method="linear",
    warmup_factor=0.001,
)
```
- `values=List[float]`: List of multiplicative factors for learning rate decay.
- `warmup_steps=int`: Number of learning rate warmup steps.
- `num_updates=int`: Total number of steps for this scheduler, usually equal to ``train.max_iter``.
- `milestones=List[int]`: List of step indices; must be increasing.
- `warmup_method=str`: Warmup method, chosen from ``{"constant", "linear"}``.
- `warmup_factor=float`: The warmup factor w.r.t. the initial value of ``scheduler``.

In this example, the parameter value increases linearly from 0.001 to 1.0 over steps 0 to 99, then stays at 1.0 for steps 100 to 599, 0.1 for steps 600 to 899, and 0.01 for steps 900 to 999. If we plot this scheduler, it looks like this:

![](./assets/multi_step_lr_scheduler.png)

<details>
<summary> Simple code for visualization </summary>

```python
import matplotlib.pyplot as plt

scheduler = instantiate(lr_multiplier)

x = []
y = []
for i in range(1000):
    x.append(i)
    y.append(scheduler(i / 1000))

# Plot the line
plt.plot(x, y, color="blue", alpha=0.7, linewidth=2.3)

# Add labels and title to the plot
plt.xlabel('Iterations')
plt.ylabel('Learning Rate')
plt.title('MultiStep Scheduler')

# Save the plot image
plt.savefig('line_plot.png')
```

</details>

### Step LR Scheduler
A modified version of the step scheduler, based on fvcore's [StepParamScheduler](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.common.param_scheduler.StepParamScheduler).

**Example:**
```python
from detectron2.config import instantiate
from detrex.config.configs.common.common_schedule import step_lr_scheduler

# define `lr_multiplier` config
lr_multiplier = step_lr_scheduler(
    values=[1.0, 0.5, 0.25, 0.1],
    warmup_steps=100,
    num_updates=1000,
    warmup_method="linear",
    warmup_factor=0.001,
)
```

It automatically divides ``num_updates`` into **equal intervals**, one per entry in ``values``, and assigns each interval the corresponding value.

In this example, the parameter value increases linearly from 0.001 to 1.0 over steps 0 to 99, then stays at 1.0 for steps 100 to 249, 0.5 for steps 250 to 499, 0.25 for steps 500 to 749, and 0.1 for steps 750 to 999. If we plot this scheduler, it looks like this:

![](./assets/step_lr_scheduler.png)
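
To sanity-check the interval assignment, here is a minimal sketch that samples the instantiated scheduler at a few steps (the chosen steps are just illustrative):

```python
from detectron2.config import instantiate

scheduler = instantiate(lr_multiplier)

# with num_updates=1000 and four values, each interval spans 250 steps;
# the scheduler is queried with the training progress `step / num_updates`
for step in [0, 99, 150, 300, 600, 800]:
    print(step, scheduler(step / 1000))
# expected (per the description above): ~0.001, ~1.0, 1.0, 0.5, 0.25, 0.1
```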


### Step LR Scheduler with Fixed Gamma

A modified version of the step scheduler with fixed gamma, based on fvcore's [StepWithFixedGammaParamScheduler](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.common.param_scheduler.StepWithFixedGammaParamScheduler).

**Example:**
```python
from detectron2.config import instantiate
from detrex.config.configs.common.common_schedule import step_lr_scheduler_with_fixed_gamma

# define `lr_multiplier` config
lr_multiplier = step_lr_scheduler_with_fixed_gamma(
    base_value=1.0,
    gamma=0.1,
    num_decays=3,
    num_updates=1000,
    warmup_steps=100,
    warmup_method="linear",
    warmup_factor=0.001,
)
```
- `base_value=float`: The base multiplicative factor.
- `num_decays=int`: The number of times the multiplicative factor decays.
- `num_updates=int`: Total number of steps for this scheduler, usually equal to ``train.max_iter``.

In this example, ``num_updates`` will be split into `num_decays + 1 = 4` parts: `[0, 250), [250, 500), [500, 750), [750, 1000)`. The parameter value increases linearly from 0.001 to 1.0 over steps 0 to 99, then stays at `1.0` for steps 100 to 249, `1.0 * 0.1 = 0.1` for steps 250 to 499, `0.1 * 0.1 = 0.01` for steps 500 to 749, and `0.01 * 0.1 = 0.001` for steps 750 to 999. If we plot this scheduler, it looks like this:

![](./assets/step_lr_with_fixed_gamma.png)
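
The per-interval values follow directly from `base_value * gamma ** k`; a small sketch of that arithmetic:

```python
base_value, gamma, num_decays = 1.0, 0.1, 3

# the k-th of the (num_decays + 1) equal intervals gets base_value * gamma ** k
for k in range(num_decays + 1):
    print(k, base_value * gamma ** k)  # 1.0, 0.1, 0.01, 0.001
```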


### Cosine LR Scheduler

A modified version of the cosine scheduler, based on fvcore's [CosineParamScheduler](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.common.param_scheduler.CosineParamScheduler).

**Example:**
```python
from detectron2.config import instantiate
from detrex.config.configs.common.common_schedule import cosine_lr_scheduler

# define `lr_multiplier` config
lr_multiplier = cosine_lr_scheduler(
    start_value=1.0,
    end_value=0.0001,
    warmup_steps=100,
    num_updates=1000,
    warmup_method="linear",
    warmup_factor=0.001,
)
```

In this example, the parameter increases linearly from 0.001 to the start value during the first 100 steps, then decays along a cosine curve to 0.0001 over steps 100 to 999. If we plot this scheduler, it looks like this:

![](./assets/cosine_lr_scheduler.png)
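
For intuition, the cosine curve follows the standard annealing formula; a minimal sketch (our own re-derivation for illustration, not fvcore's code verbatim):

```python
import math

def cosine_value(where, start_value=1.0, end_value=0.0001):
    # standard cosine annealing from start_value (where=0) to end_value (where=1)
    return end_value + 0.5 * (start_value - end_value) * (1 + math.cos(math.pi * where))

print(cosine_value(0.0))  # 1.0
print(cosine_value(0.5))  # ~0.5
print(cosine_value(1.0))  # 0.0001
```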

### Linear LR Scheduler
A modified version of the linear scheduler, based on fvcore's [LinearParamScheduler](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.common.param_scheduler.LinearParamScheduler).

**Example:**
```python
from detectron2.config import instantiate
from detrex.config.configs.common.common_schedule import linear_lr_scheduler

# define `lr_multiplier` config
lr_multiplier = linear_lr_scheduler(
    start_value=1.0,
    end_value=0.0001,
    warmup_steps=100,
    num_updates=1000,
    warmup_method="linear",
    warmup_factor=0.001,
)
```

In this example, the parameter increases linearly from 0.001 to the start value during the first 100 steps, then decreases linearly to 0.0001 over steps 100 to 999. If we plot this scheduler, it looks like this:

![](./assets/linear_lr_scheduler.png)
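
The decay itself is a plain linear interpolation; a one-function sketch (our own illustration):

```python
def linear_value(where, start_value=1.0, end_value=0.0001):
    # linear interpolation from start_value (where=0) to end_value (where=1)
    return start_value + (end_value - start_value) * where

print(linear_value(0.5))  # ~0.5
```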


### Exponential LR Scheduler
A modified version of the exponential scheduler, based on fvcore's [ExponentialParamScheduler](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.common.param_scheduler.ExponentialParamScheduler).

**Example:**
```python
from detectron2.config import instantiate
from detrex.config.configs.common.common_schedule import exponential_lr_scheduler

# define `lr_multiplier` config
lr_multiplier = exponential_lr_scheduler(
    start_value=1.0,
    decay=0.02,
    warmup_steps=100,
    num_updates=1000,
    warmup_method="linear",
    warmup_factor=0.001,
)
```

In this example, the parameter increases linearly from 0.001 to the start value during the first 100 steps, then decays exponentially with decay factor 0.02 over steps 100 to 999. If we plot this scheduler, it looks like this:

![](./assets/exponential_lr_scheduler.png)
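
The decay follows `start_value * decay ** where`, where `where` is the training progress in `[0, 1]`; a small sketch of the formula (our own illustration):

```python
def exponential_value(where, start_value=1.0, decay=0.02):
    # decays from start_value (where=0) toward start_value * decay (where=1)
    return start_value * decay ** where

for where in [0.0, 0.25, 0.5, 1.0]:
    print(where, exponential_value(where))  # 1.0, ~0.38, ~0.14, 0.02
```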

(The remaining 7 changed files are binary image assets — the scheduler plot PNGs referenced above — and cannot be displayed in the diff.)
1 change: 1 addition & 0 deletions docs/source/tutorials/index.rst
@@ -12,6 +12,7 @@ Tutorials
Download_Pretrained_Weights.md
Using_Pretrained_Backbone.md
Tools.md
Customize_Training.md
Model_Zoo.md
FAQs.md
