Merge pull request #3 from kzaleskaa/feat/add-quantization
Feat/add initial quantization
kzaleskaa authored Jun 6, 2024
2 parents 3abf2ed + 46dcad9 commit 41624e5
Showing 21 changed files with 1,421 additions and 42 deletions.
27 changes: 27 additions & 0 deletions README.md
@@ -74,3 +74,30 @@ You can override any parameter from command line like this
```bash
python src/train.py trainer.max_epochs=20 data.batch_size=64
```

## Results for BiFPN + FFNet

The base model was trained for 25 epochs. QAT was performed for 10 epochs.

**Baseline and Fuse**

<div align=center>

| Method | test/ssim (Per tensor) | model size (MB) (Per tensor) |
| ------------ | ---------------------- | ---------------------------- |
| **baseline** | 0.778 | 3.53 |
| **fuse** | 0.778 | 3.45 |

</div>

**PTQ, QAT, and PTQ + QAT (Per tensor and Per channel)**

<div align=center>

| Method | test/ssim (Per tensor) | model size (MB) (Per tensor) | test/ssim (Per channel) | model size (MB) (Per channel) |
| ------------- | ---------------------- | ---------------------------- | ----------------------- | ----------------------------- |
| **ptq** | 0.6480 | 0.96791 | 0.6518 | 0.9679 |
| **qat** | 0.7715 | 0.96791 | 0.7627 | 0.9681 |
| **ptq + qat** | 0.7724 | 0.96899 | 0.7626 | 0.9692 |

</div>
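
For reference, these rows should be reproducible with the experiment configs added in this PR. The commands below are illustrative (the checkpoint path and the `+experiment=` syntax are assumptions based on the config layout, not verified invocations):

```bash
# baseline (25 epochs)
python src/train.py experiment=example_train_baseline

# quantization variants; ckpt_path must point at the trained baseline checkpoint
python src/quantization.py +experiment=ptq_run ckpt_path=/path/to/baseline.ckpt
python src/quantization.py +experiment=qat_run ckpt_path=/path/to/baseline.ckpt
python src/quantization.py +experiment=ptq_qat_run ckpt_path=/path/to/baseline.ckpt
```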
4 changes: 2 additions & 2 deletions configs/data/depth.yaml
@@ -1,5 +1,5 @@
_target_: src.data.depth_datamodule.DepthDataModule
data_dir: ${paths.data_dir}
batch_size: 32 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
num_workers: 4
batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
num_workers: 8
pin_memory: False
30 changes: 30 additions & 0 deletions configs/experiment/example_train_baseline.yaml
@@ -0,0 +1,30 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example_train_baseline

defaults:
- override /data: depth
- override /model: depth
- override /callbacks: default
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["depth", "simple_depth_net"]

seed: 12345

trainer:
min_epochs: 10
max_epochs: 25
gradient_clip_val: 0.5

model:
optimizer:
lr: 0.002
compile: false

data:
batch_size: 64
34 changes: 34 additions & 0 deletions configs/experiment/fuse_batch_run.yaml
@@ -0,0 +1,34 @@
# @package _global_

# to execute this experiment run:
# python src/quantization.py +experiment=fuse_batch_run ckpt_path=...

defaults:
- override /data: depth
- override /model: depth
- override /callbacks: default
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["depth", "simple_depth_net"]

seed: 12345

fuse_batch: true

trainer:
min_epochs: 10
max_epochs: 25
gradient_clip_val: 0.5

model:
optimizer:
lr: 0.002
compile: false

data:
batch_size: 64

save_path: fuse_batch.pty
35 changes: 35 additions & 0 deletions configs/experiment/ptq_qat_run.yaml
@@ -0,0 +1,35 @@
# @package _global_

# to execute this experiment run:
# python src/quantization.py +experiment=ptq_qat_run ckpt_path=...

defaults:
- override /data: depth
- override /model: depth
- override /callbacks: default
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["depth", "simple_depth_net"]

seed: 12345

ptq: true
qat: true

trainer:
min_epochs: 10
max_epochs: 10
gradient_clip_val: 0.5

model:
optimizer:
lr: 0.002
compile: false

data:
batch_size: 64

save_path: ptq_qat_channel.pty
34 changes: 34 additions & 0 deletions configs/experiment/ptq_run.yaml
@@ -0,0 +1,34 @@
# @package _global_

# to execute this experiment run:
# python src/quantization.py +experiment=ptq_run ckpt_path=...

defaults:
- override /data: depth
- override /model: depth
- override /callbacks: default
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["depth", "simple_depth_net"]

seed: 12345

ptq: true

trainer:
min_epochs: 10
max_epochs: 25
gradient_clip_val: 0.5

model:
optimizer:
lr: 0.002
compile: false

data:
batch_size: 64

save_path: ptq_tensor.pty
34 changes: 34 additions & 0 deletions configs/experiment/qat_run.yaml
@@ -0,0 +1,34 @@
# @package _global_

# to execute this experiment run:
# python src/quantization.py +experiment=qat_run ckpt_path=...

defaults:
- override /data: depth
- override /model: depth
- override /callbacks: default
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["depth", "simple_depth_net"]

seed: 12345

qat: true

trainer:
min_epochs: 10
max_epochs: 10
gradient_clip_val: 0.5

model:
optimizer:
lr: 0.002
compile: false

data:
batch_size: 64

save_path: qat_channel.pty
8 changes: 5 additions & 3 deletions configs/model/depth.yaml
@@ -3,18 +3,20 @@ _target_: src.models.unet_module.UNETLitModule
optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 0.001
lr: 1e-3
weight_decay: 0.0

scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
_partial_: true
mode: min
factor: 0.1
patience: 20
threshold: 0.0001
patience: 5
threshold_mode: "abs"

net:
_target_: src.models.components.depth_net.DepthNet
_target_: src.models.components.depth_net_efficient_ffn.DepthNet

# compile model for faster training with pytorch 2.0
compile: false
33 changes: 33 additions & 0 deletions configs/quantization.yaml
@@ -0,0 +1,33 @@
# @package _global_

defaults:
- _self_
- data: depth
- model: depth
- callbacks: default
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: default
- paths: default
- extras: default
- hydra: default

fuse_batch: false
ptq: false
qat: false

save_path: name.pty

quantizer:
config:
asymmetric: true
backend: "qnnpack"
disable_requantization_for_cat: true
per_tensor: false
work_dir: "quant_output"

task_name: "quantization"

tags: ["dev"]

# passing checkpoint path is necessary for quantization
ckpt_path: ???
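
The `quantizer.config` block above is passed straight to TinyNeuralNetwork's `QATQuantizer` (see `src/quantization.py` below). A minimal sketch of that mapping; the stand-in module and dummy-input shape are illustrative:

```python
import torch
import torch.nn as nn
from tinynn.graph.quantization.quantizer import QATQuantizer

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())  # stand-in for model.net

quantizer = QATQuantizer(
    net,
    torch.randn(1, 3, 224, 224),  # example input used for tracing
    work_dir="quant_output",
    config={
        "asymmetric": True,                      # asymmetric activation ranges
        "backend": "qnnpack",                    # mobile-oriented quantized kernels
        "disable_requantization_for_cat": True,  # keep concat inputs on one scale
        "per_tensor": False,                     # False -> per-channel weight scales
    },
)
fake_quant_net = quantizer.quantize()  # rewritten model with fake-quant ops for PTQ/QAT
```

Flipping `per_tensor` is what produces the "Per tensor" vs. "Per channel" columns in the README tables.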
7 changes: 0 additions & 7 deletions notebooks/example_model_results.ipynb
@@ -114,13 +114,6 @@
"for i in range(5):\n",
" visualize_result(test_dataset[i][0], test_dataset[i][1], outputs[i])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
402 changes: 402 additions & 0 deletions notebooks/quantization.ipynb

Large diffs are not rendered by default.

@@ -308,6 +308,15 @@
"output = bifpn([c1, c2, c3])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output[0].shape, output[1].shape, output[2].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -340,7 +349,6 @@
" x1, x2, x3 = self.encoder(x)\n",
" out = self.decoder([x1, x2, x3])\n",
" cated = self.upsample_cat(out)\n",
" print(cated.shape)\n",
" return self.final_conv(cated)"
]
},
352 changes: 352 additions & 0 deletions notebooks/unet_network_efficientnet_ffnet.ipynb
@@ -0,0 +1,352 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torchvision\n",
"from torch.nn import functional as F"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**EffcientNetB0 as encoder**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class EfficientNet(nn.Module):\n",
" def __init__(self):\n",
" super(EfficientNet, self).__init__()\n",
" efficientnet = torchvision.models.efficientnet_b0()\n",
" features = efficientnet.features\n",
" self.layer1 = features[:3]\n",
" self.layer2 = features[3]\n",
" self.layer3 = features[4]\n",
"\n",
" def forward(self, x):\n",
" x1 = self.layer1(x)\n",
" x2 = self.layer2(x1)\n",
" x3 = self.layer3(x2)\n",
" return x1, x2, x3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"efficientnet = EfficientNet()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_size = (3, 224, 224)\n",
"\n",
"input_tensor = torch.randn(1, *input_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output = efficientnet(input_tensor)\n",
"\n",
"output[0].shape, output[1].shape, output[2].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**FFN as decoder**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"BN_MOMENTUM = 0.1\n",
"gpu_up_kwargs = {\"mode\": \"bilinear\", \"align_corners\": True}\n",
"mobile_up_kwargs = {\"mode\": \"nearest\"}\n",
"relu_inplace = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ConvBNReLU(nn.Module):\n",
" def __init__(\n",
" self,\n",
" in_chan,\n",
" out_chan,\n",
" ks=3,\n",
" stride=1,\n",
" padding=1,\n",
" activation=nn.ReLU,\n",
" *args,\n",
" **kwargs,\n",
" ):\n",
" super(ConvBNReLU, self).__init__()\n",
" layers = [\n",
" nn.Conv2d(\n",
" in_chan,\n",
" out_chan,\n",
" kernel_size=ks,\n",
" stride=stride,\n",
" padding=padding,\n",
" bias=False,\n",
" ),\n",
" nn.BatchNorm2d(out_chan, momentum=BN_MOMENTUM),\n",
" ]\n",
" if activation:\n",
" layers.append(activation(inplace=relu_inplace))\n",
" self.layers = nn.Sequential(*layers)\n",
"\n",
" def forward(self, x):\n",
" return self.layers(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AdapterConv(nn.Module):\n",
" def __init__(self, in_channels=[256, 512, 1024, 2048], out_channels=[64, 128, 256, 512]):\n",
" super(AdapterConv, self).__init__()\n",
" assert len(in_channels) == len(\n",
" out_channels\n",
" ), \"Number of input and output branches should match\"\n",
" self.adapter_conv = nn.ModuleList()\n",
"\n",
" for k in range(len(in_channels)):\n",
" self.adapter_conv.append(\n",
" ConvBNReLU(in_channels[k], out_channels[k], ks=1, stride=1, padding=0),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" out = []\n",
" for k in range(len(self.adapter_conv)):\n",
" out.append(self.adapter_conv[k](x[k]))\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class UpsampleCat(nn.Module):\n",
" def __init__(self, upsample_kwargs=gpu_up_kwargs):\n",
" super(UpsampleCat, self).__init__()\n",
" self._up_kwargs = upsample_kwargs\n",
"\n",
" def forward(self, x):\n",
" \"\"\"Upsample and concatenate feature maps.\"\"\"\n",
" assert isinstance(x, list) or isinstance(x, tuple)\n",
" # print(self._up_kwargs)\n",
" x0 = x[0]\n",
" _, _, H, W = x0.size()\n",
" for i in range(1, len(x)):\n",
" x0 = torch.cat([x0, F.interpolate(x[i], (H, W), **self._up_kwargs)], dim=1)\n",
" return x0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class UpBranch(nn.Module):\n",
" def __init__(\n",
" self,\n",
" in_channels=[64, 128, 256],\n",
" out_channels=[128, 128, 128],\n",
" upsample_kwargs=gpu_up_kwargs,\n",
" ):\n",
" super(UpBranch, self).__init__()\n",
"\n",
" self._up_kwargs = upsample_kwargs\n",
"\n",
" self.fam_32_sm = ConvBNReLU(in_channels[2], out_channels[2], ks=3, stride=1, padding=1)\n",
" self.fam_32_up = ConvBNReLU(in_channels[2], in_channels[1], ks=1, stride=1, padding=0)\n",
" self.fam_16_sm = ConvBNReLU(in_channels[1], out_channels[0], ks=3, stride=1, padding=1)\n",
" self.fam_16_up = ConvBNReLU(in_channels[1], in_channels[0], ks=1, stride=1, padding=0)\n",
" self.fam_8_sm = ConvBNReLU(in_channels[0], out_channels[0], ks=3, stride=1, padding=1)\n",
" # self.fam_8_up = ConvBNReLU(\n",
" # in_channels[1], in_channels[0], ks=1, stride=1, padding=0\n",
" # )\n",
" # self.fam_4 = ConvBNReLU(\n",
" # in_channels[0], out_channels[0], ks=3, stride=1, padding=1\n",
" # )\n",
"\n",
" self.high_level_ch = sum(out_channels)\n",
" self.out_channels = out_channels\n",
"\n",
" def forward(self, x):\n",
"\n",
" feat8, feat16, feat32 = x\n",
"\n",
" smfeat_32 = self.fam_32_sm(feat32)\n",
" upfeat_32 = self.fam_32_up(feat32)\n",
"\n",
" _, _, H, W = feat16.size()\n",
" x = F.interpolate(upfeat_32, (H, W), **self._up_kwargs) + feat16\n",
" smfeat_16 = self.fam_16_sm(x)\n",
" upfeat_16 = self.fam_16_up(x)\n",
"\n",
" _, _, H, W = feat8.size()\n",
" x = F.interpolate(upfeat_16, (H, W), **self._up_kwargs) + feat8\n",
" smfeat_8 = self.fam_8_sm(x)\n",
" # upfeat_8 = self.fam_8_up(x)\n",
"\n",
" # _, _, H, W = feat4.size()\n",
" # smfeat_4 = self.fam_4(\n",
" # F.interpolate(upfeat_8, (H, W), **self._up_kwargs) + feat4\n",
" # )\n",
"\n",
" return smfeat_8, smfeat_16, smfeat_32"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class UpHeadA(nn.Module):\n",
" def __init__(\n",
" self,\n",
" in_chans,\n",
" base_chans=[64, 128, 256],\n",
" upsample_kwargs=gpu_up_kwargs,\n",
" ):\n",
" layers = []\n",
" super().__init__()\n",
" layers.append(AdapterConv(in_chans, base_chans))\n",
" in_chans = base_chans[:]\n",
" layers.append(UpBranch(in_chans))\n",
" layers.append(UpsampleCat(upsample_kwargs))\n",
" self.layers = nn.Sequential(*layers)\n",
"\n",
" def forward(self, x):\n",
" return self.layers(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c1 = torch.randn([1, 24, 56, 56])\n",
"c2 = torch.randn([1, 40, 28, 28])\n",
"c3 = torch.randn([1, 80, 14, 14])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"up_head_a = UpHeadA([24, 40, 80])\n",
"\n",
"out_A = up_head_a([c1, c2, c3])\n",
"print(\"output A: \", out_A.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**EffcientNetB0 + FFN**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class EfficientNetFPN(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.encoder = EfficientNet()\n",
" self.decoder = UpHeadA([24, 40, 80])\n",
" self.final_conv = nn.Conv2d(in_channels=384, out_channels=1, kernel_size=3, padding=\"same\")\n",
"\n",
" def forward(self, x):\n",
" x1, x2, x3 = self.encoder(x)\n",
" x = self.decoder([x1, x2, x3])\n",
" x = self.final_conv(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"efficientnet_fpn = EfficientNetFPN()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"efficientnet_fpn(input_tensor).shape"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
116 changes: 94 additions & 22 deletions requirements.txt
@@ -1,24 +1,96 @@
# --------- pytorch --------- #
torch>=2.0.0
torchvision>=0.15.0
lightning>=2.0.0
torchmetrics>=0.11.4

# --------- hydra --------- #
hydra-core==1.3.2
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
alembic==1.13.1
antlr4-python3-runtime==4.9.3
async-timeout==4.0.3
attrs==23.2.0
autopage==0.5.2
cfgv==3.4.0
cliff==4.7.0
cmaes==0.10.0
cmd2==2.4.3
colorlog==6.8.2
distlib==0.3.8
exceptiongroup==1.2.1
filelock==3.14.0
flatbuffers==24.3.25
frozenlist==1.4.1
fsspec==2024.5.0
greenlet==3.0.3
grpcio==1.64.0
hydra-colorlog==1.2.0
hydra-core==1.3.2
hydra-optuna-sweeper==1.2.0

# --------- loggers --------- #
# wandb
# neptune-client
# mlflow
# comet-ml
# aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550

# --------- others --------- #
rootutils # standardizing the project root setup
pre-commit # hooks for applying linters on commit
rich # beautiful text formatting in terminal
pytest # tests
# sh # for running bash commands in some tests (linux/macos only)
identify==2.5.36
idna==3.7
igraph==0.11.5
iniconfig==2.0.0
Jinja2==3.1.4
lightning==2.2.5
lightning-utilities==0.11.2
Mako==1.3.5
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
nodeenv==1.9.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.40
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
optuna==2.10.1
packaging==24.0
pandas==2.2.2
pbr==6.0.0
pillow==10.3.0
platformdirs==4.2.2
pluggy==1.5.0
pre-commit==3.7.1
prettytable==3.10.0
protobuf==5.27.0
Pygments==2.18.0
pyperclip==1.8.2
pytest==8.2.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
pytorch-lightning==2.2.5
pytz==2024.1
PyYAML==6.0.1
rich==13.7.1
rootutils==1.0.7
ruamel.yaml==0.18.6
ruamel.yaml.clib==0.2.8
scipy==1.13.1
six==1.16.0
SQLAlchemy==2.0.30
stevedore==5.2.0
sympy==1.12.1
tensorboard==2.16.2
tensorboard-data-server==0.7.2
texttable==1.7.0
TinyNeuralNetwork @ git+https://github.com/alibaba/TinyNeuralNetwork.git@8c1f2ce00e9584318092956dcbf99f6eb587992c
tomli==2.0.1
torch==2.3.0
torchmetrics==1.4.0.post0
torchvision==0.18.0
tqdm==4.66.4
triton==2.3.0
typing_extensions==4.12.0
tzdata==2024.1
virtualenv==20.26.2
wcwidth==0.2.13
Werkzeug==3.0.3
yarl==1.9.4
4 changes: 4 additions & 0 deletions src/data/components/nyu_dataset.py
@@ -1,6 +1,7 @@
import os
from typing import Tuple

import numpy as np
import pandas as pd
import torch
from PIL import Image
@@ -31,6 +32,9 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
img = Image.open(img_path)
mask = Image.open(mask_path)

# img = np.array(img)
mask = np.asarray(mask, np.float64)

if self.transform:
img = self.transform(img)
if self.target_transform:
10 changes: 5 additions & 5 deletions src/data/depth_datamodule.py
@@ -78,12 +78,12 @@ def __init__(
[transforms.PILToTensor(), transforms.Resize((224, 224))]
)
self.transforms_mask_train = transforms.Compose(
[transforms.PILToTensor(), BilinearInterpolation((56, 56))]
[transforms.ToTensor(), BilinearInterpolation((56, 56))]
)
self.transforms_mask = transforms.Compose(
[
transforms.PILToTensor(),
NormalizeData(10_000 * (1 / 255)),
transforms.ToTensor(),
NormalizeData(10_000.0 * (1.0 / 255.0)),
BilinearInterpolation((56, 56)),
]
)
@@ -156,7 +156,7 @@ def train_dataloader(self) -> DataLoader[Any]:
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
# persistent_workers=True,
persistent_workers=True,
)

def val_dataloader(self) -> DataLoader[Any]:
@@ -170,7 +170,7 @@ def val_dataloader(self) -> DataLoader[Any]:
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
# persistent_workers=True,
persistent_workers=True,
)

def test_dataloader(self) -> DataLoader[Any]:
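
The switch from `PILToTensor` to `ToTensor` is significant here: `ToTensor` converts 8-bit PIL images to float and rescales them to `[0, 1]` (an implicit divide by 255), whereas `PILToTensor` keeps the raw integer values. That is presumably why the normalization constant is written as `10_000.0 * (1.0 / 255.0)`. A small sketch of the difference:

```python
import numpy as np
from PIL import Image
from torchvision import transforms

img = Image.fromarray(np.full((2, 2), 255, dtype=np.uint8))

print(transforms.PILToTensor()(img).max())  # tensor(255, dtype=torch.uint8): raw values
print(transforms.ToTensor()(img).max())     # tensor(1.): rescaled to [0, 1]
```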
src/models/components/depth_net.py
@@ -14,7 +14,7 @@ def __init__(self, in_features: List[int] = [24, 40, 80]):
self.decoder = BiFPN(in_features)
self.upsample_4 = nn.Upsample(scale_factor=4, mode="nearest")
self.upsample_2 = nn.Upsample(scale_factor=2, mode="nearest")
self.final_conv = nn.Conv2d(in_channels=192, out_channels=1, kernel_size=3, padding="same")
self.final_conv = nn.Conv2d(in_channels=192, out_channels=1, kernel_size=3, padding=1)

def upsample_cat(self, x: List[torch.Tensor]) -> torch.Tensor:
p4, p5, p6 = x
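
The `padding="same"` to `padding=1` swap in `depth_net.py` does not change the math: for a stride-1 3x3 convolution the two are equivalent, and integer padding avoids string padding modes that some tracing/quantization toolchains reject (the latter motive is an inference, not stated in the commit). A quick check:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 192, 56, 56)
a = nn.Conv2d(192, 1, kernel_size=3, padding="same")
b = nn.Conv2d(192, 1, kernel_size=3, padding=1)
b.load_state_dict(a.state_dict())  # copy weights so outputs are comparable

# Identical outputs: "same" and padding=1 coincide for 3x3 kernels at stride 1.
assert torch.allclose(a(x), b(x))
```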
21 changes: 21 additions & 0 deletions src/models/components/depth_net_efficient_ffn.py
@@ -0,0 +1,21 @@
from typing import List

import torch
import torch.nn as nn

from .efficientnet_encoder import EfficientNet
from .ffnet_decoder import UpHeadA


class DepthNet(nn.Module):
def __init__(self, in_features: List[int] = [24, 40, 80]):
super().__init__()
self.encoder = EfficientNet()
self.decoder = UpHeadA(in_features)
self.final_conv = nn.Conv2d(in_channels=384, out_channels=1, kernel_size=3, padding=1)

def forward(self, x):
x1, x2, x3 = self.encoder(x)
x = self.decoder([x1, x2, x3])
x = self.final_conv(x)
return x
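
A quick shape check for the new encoder-decoder (a sketch; the 224x224 input matches the datamodule's resize, and the expected output size follows from the 56x56 mask transforms):

```python
import torch

from src.models.components.depth_net_efficient_ffn import DepthNet

net = DepthNet().eval()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    out = net(x)

# EfficientNet-B0 stages give 24/40/80 channels at 56x56, 28x28, 14x14;
# UpHeadA concatenates 3x128 channels at 56x56, and final_conv maps 384 -> 1.
print(out.shape)  # expected: torch.Size([1, 1, 56, 56])
```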
135 changes: 135 additions & 0 deletions src/models/components/ffnet_decoder.py
@@ -0,0 +1,135 @@
import torch
import torch.nn as nn
import torchvision
from torch.nn import functional as F

BN_MOMENTUM = 0.1
gpu_up_kwargs = {"mode": "bilinear", "align_corners": True}
mobile_up_kwargs = {"mode": "nearest"}
relu_inplace = True


class ConvBNReLU(nn.Module):
def __init__(
self,
in_chan,
out_chan,
ks=3,
stride=1,
padding=1,
activation=nn.ReLU,
*args,
**kwargs,
):
super().__init__()
layers = [
nn.Conv2d(
in_chan,
out_chan,
kernel_size=ks,
stride=stride,
padding=padding,
bias=False,
),
nn.BatchNorm2d(out_chan, momentum=BN_MOMENTUM),
]
if activation:
layers.append(activation(inplace=relu_inplace))
self.layers = nn.Sequential(*layers)

def forward(self, x):
return self.layers(x)


class AdapterConv(nn.Module):
def __init__(self, in_channels=[256, 512, 1024], out_channels=[64, 128, 256]):
super().__init__()
assert len(in_channels) == len(
out_channels
), "Number of input and output branches should match"
self.adapter_conv = nn.ModuleList()

for k in range(len(in_channels)):
self.adapter_conv.append(
ConvBNReLU(in_channels[k], out_channels[k], ks=1, stride=1, padding=0),
)

def forward(self, x):
out = []
for k in range(len(self.adapter_conv)):
out.append(self.adapter_conv[k](x[k]))
return out


class UpsampleCat(nn.Module):
def __init__(self, upsample_kwargs=gpu_up_kwargs):
super().__init__()
self._up_kwargs = upsample_kwargs

def forward(self, x):
"""Upsample and concatenate feature maps."""
assert isinstance(x, list) or isinstance(x, tuple)
# print(self._up_kwargs)
x0 = x[0]
_, _, H, W = x0.size()
for i in range(1, len(x)):
x0 = torch.cat([x0, F.interpolate(x[i], (H, W), **self._up_kwargs)], dim=1)
return x0


class UpBranch(nn.Module):
def __init__(
self,
in_channels=[64, 128, 256],
out_channels=[128, 128, 128],
upsample_kwargs=gpu_up_kwargs,
):
super().__init__()

self._up_kwargs = upsample_kwargs

self.fam_32_sm = ConvBNReLU(in_channels[2], out_channels[2], ks=3, stride=1, padding=1)
self.fam_32_up = ConvBNReLU(in_channels[2], in_channels[1], ks=1, stride=1, padding=0)
self.fam_16_sm = ConvBNReLU(in_channels[1], out_channels[0], ks=3, stride=1, padding=1)
self.fam_16_up = ConvBNReLU(in_channels[1], in_channels[0], ks=1, stride=1, padding=0)
self.fam_8_sm = ConvBNReLU(in_channels[0], out_channels[0], ks=3, stride=1, padding=1)

self.high_level_ch = sum(out_channels)
self.out_channels = out_channels

def forward(self, x):

feat8, feat16, feat32 = x

smfeat_32 = self.fam_32_sm(feat32)
upfeat_32 = self.fam_32_up(feat32)

_, _, H, W = feat16.size()
x = F.interpolate(upfeat_32, (H, W), **self._up_kwargs) + feat16
smfeat_16 = self.fam_16_sm(x)
upfeat_16 = self.fam_16_up(x)

_, _, H, W = feat8.size()
x = F.interpolate(upfeat_16, (H, W), **self._up_kwargs) + feat8
smfeat_8 = self.fam_8_sm(x)

return smfeat_8, smfeat_16, smfeat_32


class UpHeadA(nn.Module):
def __init__(
self,
in_chans,
base_chans=[64, 128, 256],
upsample_kwargs=gpu_up_kwargs,
):
layers = []
super().__init__()
layers.append(AdapterConv(in_chans, base_chans))
in_chans = base_chans[:]
layers.append(UpBranch(in_chans))
layers.append(UpsampleCat(upsample_kwargs))
self.layers = nn.Sequential(*layers)

def forward(self, x):
return self.layers(x)
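
The channel bookkeeping in `UpHeadA` can be verified by hand, as in the notebook above (shapes taken from the EfficientNet-B0 stages for a 224x224 input):

```python
import torch

from src.models.components.ffnet_decoder import UpHeadA

c1 = torch.randn(1, 24, 56, 56)  # stride-4 features
c2 = torch.randn(1, 40, 28, 28)  # stride-8 features
c3 = torch.randn(1, 80, 14, 14)  # stride-16 features

head = UpHeadA([24, 40, 80]).eval()
out = head([c1, c2, c3])

# AdapterConv maps channels to [64, 128, 256]; UpBranch emits 128 channels at
# each scale; UpsampleCat concatenates them at 56x56: 3 * 128 = 384 channels.
print(out.shape)  # expected: torch.Size([1, 384, 56, 56])
```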
2 changes: 1 addition & 1 deletion src/models/unet_module.py
@@ -125,7 +125,7 @@ def training_step(
self.train_loss(loss)
self.train_ssim(preds, targets)
self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
self.log("train/acc", self.train_ssim, on_step=False, on_epoch=True, prog_bar=True)
self.log("train/ssim", self.train_ssim, on_step=False, on_epoch=True, prog_bar=True)

# return loss or backpropagation will fail
return loss
163 changes: 163 additions & 0 deletions src/quantization.py
@@ -0,0 +1,163 @@
import os
from typing import Any, Dict, List, Optional, Tuple

import hydra
import lightning as L
import rootutils
import torch
import torch.ao.quantization.quantize_fx as quantize_fx
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from tinynn.graph.quantization.quantizer import QATQuantizer

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src.utils import (
DisplayReults,
RankedLogger,
extras,
get_metric_value,
instantiate_callbacks,
instantiate_loggers,
log_hyperparameters,
task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)


def calibration(model, dataloader, num_iterations):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
count = 0

with torch.no_grad():
for data in dataloader:
img, _ = data
img = img.to(device)
model(img)

count += 1
if count >= num_iterations:
break

return model


@task_wrapper
def quantization(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
assert cfg.ckpt_path

log.info(f"Instantiating datamodule <{cfg.data._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

log.info(f"Instantiating model <{cfg.model._target_}>")
model_class = hydra.utils.get_class(cfg.model._target_)
model: LightningModule = model_class.load_from_checkpoint(cfg.ckpt_path)

log.info("Instantiating callbacks...")
callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))
callbacks.append(DisplayReults())

log.info("Instantiating loggers...")
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

object_dict = {
"cfg": cfg,
"datamodule": datamodule,
"model": model,
"logger": logger,
"trainer": trainer,
}

if logger:
log.info("Logging hyperparameters!")
log_hyperparameters(object_dict)

log.info("Starting testing!")
trainer.test(model=model, datamodule=datamodule)

torch.save(model.net.state_dict(), cfg.save_path)

if cfg.fuse_batch:
log.info("Fuse modules!")
model.net = quantize_fx.fuse_fx(model.net.eval())
trainer.test(model=model, datamodule=datamodule)
torch.save(model.net.state_dict(), cfg.save_path)

if cfg.ptq or cfg.qat:
quantizer = QATQuantizer(
model.net,
torch.randn(1, 3, 52, 52),
work_dir=cfg.quantizer.work_dir,
config=cfg.quantizer,
)
model.net = quantizer.quantize()

if cfg.ptq:
log.info("Post training quantization!")
model.net.apply(torch.quantization.disable_fake_quant)
model.net.apply(torch.quantization.enable_observer)

calibration(model.net, datamodule.train_dataloader(), 50)

model.net.apply(torch.quantization.disable_observer)
model.net.apply(torch.quantization.enable_fake_quant)

if cfg.qat:
log.info("Quantization awareness training!")
trainer.fit(model=model, datamodule=datamodule)

with torch.no_grad():
model.net.eval()
model.net.cpu()
quantized_model = torch.quantization.convert(model.net)
torch.save(quantized_model.state_dict(), cfg.save_path)

model.net = quantized_model

log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
cfg.trainer["accelerator"] = "cpu"
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

trainer.test(model=model, datamodule=datamodule)

return None, None


@hydra.main(version_base="1.3", config_path="../configs", config_name="quantization.yaml")
def main(cfg: DictConfig) -> None:
"""Main entry point for evaluation.
:param cfg: DictConfig configuration composed by Hydra.
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
extras(cfg)

quantization(cfg)


if __name__ == "__main__":
main()
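
A plausible direct invocation (the flags mirror the globals in `configs/quantization.yaml`; the checkpoint path is illustrative, and `ckpt_path` is mandatory because it defaults to `???`):

```bash
python src/quantization.py ckpt_path=/path/to/baseline.ckpt \
    ptq=true qat=true quantizer.config.per_tensor=true save_path=ptq_qat_tensor.pty
```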
