From d4ac0efefcf167cea98dd29cade799526fcfd69f Mon Sep 17 00:00:00 2001 From: ceci3 Date: Tue, 19 Dec 2023 18:42:50 +0800 Subject: [PATCH 1/5] support nf4 channel wise quant & fix bug when blocksize>512 (#1817) (#1818) --- csrc/lc/dequantize_blockwise.cu | 84 ++++++++++++++++++++--- csrc/lc/quantize_blockwise.cu | 115 ++++++++++++++++++++++++-------- 2 files changed, 162 insertions(+), 37 deletions(-) diff --git a/csrc/lc/dequantize_blockwise.cu b/csrc/lc/dequantize_blockwise.cu index 8046c34ac..0bf76a163 100644 --- a/csrc/lc/dequantize_blockwise.cu +++ b/csrc/lc/dequantize_blockwise.cu @@ -201,7 +201,6 @@ template __global__ void kDequantizeBlockwise(const floa //template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(const float *code, const unsigned char * A, const float * absmax, __nv_bfloat16 *out, int blocksize, int n); - template void dequantize_blockwise(const float *code, const unsigned char *A, const float *absmax, T *out, int blocksize, int n) { int num_blocks = n/blocksize; @@ -226,6 +225,50 @@ template void dequantize_blockwise(const float *code, const unsigned //template void dequantize_blockwise<__nv_bfloat16, FP4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n); //template void dequantize_blockwise<__nv_bfloat16, NF4>(const float *code, const unsigned char *A, const float *absmax, __nv_bfloat16 *out, int blocksize, int n); +template +__global__ void kDequantizeChannelwise(const unsigned char* A, + const float *absmax, + float *out, + int n, + int cout) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + + int num = n / 2; + //int part_n = num / cout; + for (int i = idx; i < num; i += blockDim.x * gridDim.x) { + float local_absmax = absmax[i%cout]; + int idx = 2*(i/cout)* cout + i%cout; + switch(DATA_TYPE) + { + case FP4: + out[i*2 + i%cout] = dDequantizeFP4Tree(A[i] >> 4, local_absmax); + out[i*2 + cout + i%cout] = dDequantizeFP4Tree(A[i] & 0x0F, local_absmax); + break; + case NF4: + out[idx] = dDequantizeNF4(A[i] >> 4)* local_absmax; + out[idx + cout] = dDequantizeNF4(A[i] & 0x0F)* local_absmax; + break; + } + __syncthreads(); + } +} + +template void dequantize_channelwise(const unsigned char *A, const float *absmax, T *out, int n, int cout) +{ + int max_threads = 1024; + int64_t block_size = + std::min(static_cast(n), + static_cast(max_threads/ 4)); + + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (n + block_size - 1) / block_size); + + kDequantizeChannelwise<<>>(A, absmax, out, n, cout); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + std::vector DequantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, const paddle::Tensor& absmax, int blocksize, std::string quant_type) { int64_t input_numel = input.numel(); int n = input_numel; @@ -234,23 +277,44 @@ std::vector DequantizeBlockwise(const paddle::Tensor& input, con out_shape = {input_numel * 2, 1}; n = n * 2; } + if (blocksize == -1) { + out_shape = {input.shape()[0] * 2, input.shape()[1]}; + } auto out = paddle::empty(out_shape, paddle::DataType::FLOAT32, input.place()); - if (quant_type == "8bit") - dequantize_blockwise(code.data(), input.data(), absmax.data(), out.data(), blocksize, n); - else if (quant_type == "nf4") - dequantize_blockwise(NULL, input.data(), absmax.data(), out.data(), blocksize, n); - else if (quant_type == "fp4") - dequantize_blockwise(NULL, input.data(), absmax.data(), out.data(), blocksize, n); - 
else - PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. "); + if (blocksize == -1) { + if (quant_type == "8bit") + PD_THROW("blocksize is -1 only support NF4 and FP4."); + else + blocksize = n / absmax.numel() * 2; + + int cout = input.shape()[1]; + if (quant_type == "nf4") + dequantize_channelwise(input.data(), absmax.data(), out.data(), n, cout); + else if (quant_type == "fp4") + dequantize_channelwise(input.data(), absmax.data(), out.data(), n, cout); + else + PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. "); + } else { + if (quant_type == "8bit") + dequantize_blockwise(code.data(), input.data(), absmax.data(), out.data(), blocksize, n); + else if (quant_type == "nf4") + dequantize_blockwise(NULL, input.data(), absmax.data(), out.data(), blocksize, n); + else if (quant_type == "fp4") + dequantize_blockwise(NULL, input.data(), absmax.data(), out.data(), blocksize, n); + else + PD_THROW("NOT supported quant type. Only 8bit, nf4, fp4 are supported. "); + } return {out}; }; std::vector> GetDequantizeBlockwiseInferShape(const std::vector& input_shape, const std::vector& code_shape, const std::vector& abs_max_shape, int blocksize, std::string quant_type){ int64_t first_shape = input_shape[0] * input_shape[1] * 2; if (quant_type != "8bit") - return {{first_shape, 1}}; + if (blocksize != -1) + return {{first_shape, 1}}; + else + return {{input_shape[0] * 2, input_shape[1]}}; else return {input_shape}; } diff --git a/csrc/lc/quantize_blockwise.cu b/csrc/lc/quantize_blockwise.cu index d4f6ff2ca..e8e55b9d8 100644 --- a/csrc/lc/quantize_blockwise.cu +++ b/csrc/lc/quantize_blockwise.cu @@ -279,6 +279,7 @@ __global__ void kQuantizeBlockwise(const float * code, const T * __restrict__ A, #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH/2; j++) { + packed_4bit = 0; packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); qvals[j] = packed_4bit; @@ -360,9 +361,39 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, NF4) +template +__global__ void kQuantizeChannelwise(const float *code, + const T* A, + unsigned char* out, + float *absmax, + int n, + int cout) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + + int num = n / 2; + for (int i = idx; i < num; i += blockDim.x * gridDim.x) { + int idx = 2*(i/cout)* cout + i%cout; + float local_absmax = absmax[i %cout]; + float inv_local_absmax = 1.0f/local_absmax; + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case FP4: + packed_4bit |= dQuantizeFP4(((float)A[idx])*inv_local_absmax) << 4; + packed_4bit |= dQuantizeFP4(((float)A[idx+cout])*inv_local_absmax); + out[i] = packed_4bit; + break; + case NF4: + packed_4bit |= dQuantizeNF4(((float)A[idx])*inv_local_absmax) << 4; + packed_4bit |= dQuantizeNF4(((float)A[idx+cout])*inv_local_absmax); + out[i] = packed_4bit; + break; + } + } +} -template void quantize_blockwise(const float *code, const paddle::Tensor& A, float *absmax, unsigned char *out, int blocksize, int n) +template void quantize_blockwise(const float *code, const paddle::Tensor& A, paddle::Tensor& absmax, unsigned char *out, int blocksize, int n, int channelwise) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; @@ -372,22 +403,43 @@ template void quantize_blockwise(const float num_blocks = n % blocksize == 0 ? 
num_blocks : num_blocks + 1; const DataType_* A_data = reinterpret_cast(A.data()); - if(blocksize == 4096) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 2048) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 1024) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 512) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 256) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 128) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else if(blocksize == 64) - kQuantizeBlockwise<<>>(code, A_data, absmax, out, n); - else - PD_THROW("only support blocksize is [64, 128, 256, 512, 1024, 2048, 4096]."); + if (channelwise == 0) { + if(blocksize == 4096) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 2048) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 1024) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 512) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 256) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 128) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + else if(blocksize == 64) + kQuantizeBlockwise<<>>(code, A_data, absmax.data(), out, n); + } + else { + if (DATA_TYPE == General8bit) + PD_THROW("blocksize is -1 only support NF4 and FP4."); + + int cout = A.shape()[1]; + int max_threads = 1024; + + absmax = A.abs().max({0}); + + int64_t block_size = + std::min(static_cast(n), + static_cast(max_threads/ 4)); + + const int64_t max_blocks = + std::max(((max_threads - 1) / block_size + 1), static_cast(1)); + const int64_t grid_size = + std::min(max_blocks, (n + block_size - 1) / block_size); + + kQuantizeChannelwise<<>>( + code, A_data, out, absmax.data(), n, cout); + } CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -395,38 +447,44 @@ template void quantize_blockwise(const float std::vector QuantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, int blocksize, std::string quant_type) { int n = input.numel(); + int channelwise = 0; std::vector out_shape = input.shape(); if (quant_type != "8bit") { // 4bit out_shape = {(n + 1) / 2, 1}; } + if (blocksize == -1){ + blocksize = input.shape()[0]; + out_shape = {input.shape()[0]/2, input.shape()[1]}; + channelwise = 1; + } auto out = paddle::empty(out_shape, paddle::DataType::UINT8, input.place()); int64_t absmax_shape = n / blocksize; auto absmax = paddle::empty({absmax_shape}, paddle::DataType::FLOAT32, input.place()); switch(input.type()) { case paddle::DataType::FLOAT32: if (quant_type == "8bit") - quantize_blockwise(code.data(), input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(code.data(), input, absmax, out.data(), blocksize, n, channelwise); else if (quant_type == "nf4") { - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); } else if (quant_type == "fp4") - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); return {out, absmax}; case paddle::DataType::FLOAT16: if (quant_type == "8bit") - quantize_blockwise(code.data(), input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(code.data(), input, absmax, out.data(), blocksize, n, 
channelwise); else if (quant_type == "nf4") - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); else if (quant_type == "fp4") - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); return {out, absmax}; case paddle::DataType::BFLOAT16: if (quant_type == "8bit") - quantize_blockwise(code.data(), input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(code.data(), input, absmax, out.data(), blocksize, n, channelwise); else if (quant_type == "nf4") - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); else if (quant_type == "fp4") - quantize_blockwise(NULL, input, absmax.data(), out.data(), blocksize, n); + quantize_blockwise(NULL, input, absmax, out.data(), blocksize, n, channelwise); return {out, absmax}; default: @@ -440,7 +498,10 @@ std::vector QuantizeBlockwise(const paddle::Tensor& input, const std::vector> GetQuantizeBlockwiseInferShape(const std::vector& input_shape, const std::vector& code_shape, int blocksize, std::string quant_type){ int64_t first_shape = (input_shape[0] * input_shape[1] + 1) / 2; if (quant_type != "8bit") - return {{first_shape, 1}}; + if (blocksize != -1) + return {{first_shape, 1}}; + else + return {{input_shape[0]/2, input_shape[1]}}; else return {input_shape}; } From dcf79e930694beded6965a21220f7001551f21eb Mon Sep 17 00:00:00 2001 From: Chang Xu Date: Tue, 26 Dec 2023 15:52:31 +0800 Subject: [PATCH 2/5] Add GroupWiseQuant & AWQ & AutoClip (#1821) --- paddleslim/quant/advanced/__init__.py | 8 +- paddleslim/quant/advanced/auto_clip.py | 155 ++++++++++++++++++ paddleslim/quant/advanced/awq_search.py | 78 +++++++++ paddleslim/quant/advanced/piecewise_search.py | 39 +++-- paddleslim/quant/advanced/smooth.py | 56 ++++--- paddleslim/quant/advanced/utils.py | 20 ++- paddleslim/quant/observers/__init__.py | 2 + paddleslim/quant/observers/groupwise.py | 112 +++++++++++++ 8 files changed, 428 insertions(+), 42 deletions(-) create mode 100644 paddleslim/quant/advanced/auto_clip.py create mode 100644 paddleslim/quant/advanced/awq_search.py create mode 100644 paddleslim/quant/observers/groupwise.py diff --git a/paddleslim/quant/advanced/__init__.py b/paddleslim/quant/advanced/__init__.py index 1f0744ecf..2e779a6e1 100644 --- a/paddleslim/quant/advanced/__init__.py +++ b/paddleslim/quant/advanced/__init__.py @@ -19,6 +19,8 @@ from . import sample from . import layerwise_quant_error from . import utils_layers +from . import awq_search +from . import auto_clip from .gptq import * from .smooth import * @@ -27,6 +29,8 @@ from .sample import * from .layerwise_quant_error import * from .utils_layers import * +from .awq_search import * +from .auto_clip import * __all__ = [] __all__ += gptq.__all__ @@ -35,4 +39,6 @@ __all__ += piecewise_search.__all__ __all__ += sample.__all__ __all__ += layerwise_quant_error.__all__ -__all__ += utils_layers.__all__ \ No newline at end of file +__all__ += utils_layers.__all__ +__all__ += awq_search.__all__ +__all__ += auto_clip.__all__ \ No newline at end of file diff --git a/paddleslim/quant/advanced/auto_clip.py b/paddleslim/quant/advanced/auto_clip.py new file mode 100644 index 000000000..696901110 --- /dev/null +++ b/paddleslim/quant/advanced/auto_clip.py @@ -0,0 +1,155 @@ +# Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import numpy as np +from .utils import fake_quant +from .metrics import mse_loss +from paddle.distributed.fleet.meta_parallel import ( + ColumnParallelLinear, + RowParallelLinear, +) +__all__ = ['AutoClip'] + +class AutoClip(nn.Layer): + """ + AutoClip from AWQ[https://arxiv.org/abs/2306.00978] + """ + def __init__( + self, + model, + weight_bits=8, + weight_quant_method='groupwise', + loss_function=mse_loss, + sample_function=None, + n_grid=20, + max_shrink=0.5, + n_sample_token=128, + group_size=-1, + ): + super(AutoClip, self).__init__() + self.model = model + self.weight_bits = weight_bits + self.weight_method = weight_quant_method + self.loss_function = loss_function + self.n_grid = n_grid + self.max_shrink = max_shrink + self.n_sample_token = n_sample_token + self.bnt = (1 << (self.weight_bits - 1)) - 1 + self.sampled_inputs = {} + self.sample_function = sample_function + self.group_size = group_size + + self._apply_hook() + + def _apply_hook(self): + self._forward_hook_list = [] + for _, sub_layer in self.model.named_sublayers(): + if type(sub_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]: + forward_pre_hook_handle = sub_layer.register_forward_pre_hook( + self._forward_pre_hook) + self._forward_hook_list.append(forward_pre_hook_handle) + + def _forward_pre_hook(self, layer, input): + self._sample_scale(input, layer.full_name()) + return input + + def _sample_scale(self, input, name): + input = input[0] if type(input) == tuple else input + input.stop_gradient = True + if name not in self.sampled_inputs: + self.sampled_inputs[name] = input + else: + if self.sample_function is not None: + self.sampled_inputs[name] = self.sample_function.sample( + input, self.sampled_inputs[name], name) + else: + self.sampled_inputs[name] = input + + + def auto_clip(self, group_size=128, oc_batch_size=1024): + """ + search clip scale for each layer and update the layer's weight + """ + for sub_name, sub_layer in self.model.named_sublayers(): + name = sub_layer.full_name() + if name not in self.sampled_inputs: + continue + print('AutoClipping', sub_name, name) + weight = sub_layer.weight.cast('float16') + weight_t = paddle.transpose(weight, perm=[1, 0]) + x = self.sampled_inputs[name].cast('float16') + x = x.reshape([-1, x.shape[-1]]) + x = x.reshape([1, x.shape[0], -1, group_size]) + x = x[:, 0::x.shape[1] // self.n_sample_token] + weight_t = weight_t.reshape([weight_t.shape[0], 1, -1, group_size]) + # fast test + # oc_batch_size = weight_t.shape[0] // 4 + oc_batch_size = oc_batch_size if weight_t.shape[0] % oc_batch_size == 0 else 128 # prevent OOM + assert weight_t.shape[0] % oc_batch_size == 0 + + w_all = weight_t + best_max_val_all = [] + + for i_b in range(weight_t.shape[0] // oc_batch_size): + w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size] + + org_max_val = w.abs().max(axis=-1, keepdim=True) # co, 1, n_group, 1 + best_max_val = org_max_val.clone() + min_errs = 
paddle.ones_like(org_max_val, dtype='float16') * 1e9 + org_out = (x * w).sum(axis=-1) # co, n_token, n_group + for i_s in range(int(self.max_shrink * self.n_grid)): + max_val = org_max_val * (1 - i_s / self.n_grid) + max_val_tmp = max_val + cur_w = paddle.where(w > max_val_tmp, max_val_tmp, w) + cur_w = paddle.where(cur_w < - max_val_tmp, - max_val_tmp, cur_w) + quant_dequant_weight = fake_quant(cur_w, method='abs_max', weight_bits=4) + cur_out = (x * quant_dequant_weight).sum(axis=-1) + # co, 1, n_group, 1 + tmp = (cur_out - org_out).detach().clone() + err = paddle.pow(tmp, 2).mean(axis=1).reshape(min_errs.shape) + print('block {} search s {} err {}'.format(i_b, i_s, err.mean().item())) + del cur_w, cur_out, quant_dequant_weight, tmp + paddle.device.cuda.empty_cache() + + cur_best_idx = paddle.where(err < min_errs) + if cur_best_idx[0].shape[0] != 0: + min_errs[cur_best_idx] = err[cur_best_idx] + best_max_val[cur_best_idx] = max_val[cur_best_idx] + best_max_val_all.append(best_max_val) + + del org_out, org_max_val, min_errs, best_max_val, err, cur_best_idx, max_val_tmp, max_val, w + paddle.device.cuda.empty_cache() + + best_max_val = paddle.concat(best_max_val_all, axis=0) + best_max_val = paddle.squeeze(best_max_val, axis=1) + for param in sub_layer.parameters(include_sublayers=False): + if 'w_0' in param.name: + param_tmp = param.transpose(perm=[1, 0]).cast('float16') + tmp_shape = param_tmp.shape + param_tmp = param_tmp.reshape([best_max_val.shape[0], best_max_val.shape[1], -1]) + best_max_val = paddle.tile(best_max_val, repeat_times=(1, 1, param_tmp.shape[-1])) + param_tmp = paddle.where(param_tmp > best_max_val, best_max_val, param_tmp) + param_tmp = paddle.where(param_tmp < - best_max_val, - best_max_val, param_tmp) + param_tmp = param_tmp.reshape(tmp_shape).cast(param.dtype) + param_tmp = param_tmp.transpose(perm=[1, 0]) + paddle.assign(param_tmp, output=param) + del param_tmp + paddle.device.cuda.empty_cache() + break + + del best_max_val, weight_t, x, weight, self.sampled_inputs[name], w_all, best_max_val_all + paddle.device.cuda.empty_cache() + diff --git a/paddleslim/quant/advanced/awq_search.py b/paddleslim/quant/advanced/awq_search.py new file mode 100644 index 000000000..55151c4e8 --- /dev/null +++ b/paddleslim/quant/advanced/awq_search.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import numpy as np +from .utils import compute_scales +from .metrics import mse_loss +__all__ = ['AWQSearch'] + +class AWQSearch(): + def __init__(self, + n_grid=20, + bits_length=4, + weight_quant_method='groupwise', + group_size=128, + loss_function=mse_loss): + ''' + The implementation of AutoScale from AWQ(https://arxiv.org/pdf/2306.00978.pdf). 
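        (Summary of the search loop below: for each candidate ratio r in
        {0, 1/n_grid, ..., (n_grid-1)/n_grid}, the activation abs-max a is mapped to a
        per-channel scale s0 = clip(a**r, min=1e-4), s = s0 / sqrt(s0.max() * s0.min());
        the weight is multiplied by s, the activation divided by s, the scaled weight is
        fake-quantized, and the ratio whose output has the smallest loss against the
        original matmul output is returned.)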
+ ''' + self.n_grid = n_grid + self.bits_length = bits_length + self.weight_quant_method = weight_quant_method + self.bnt = (1 << (bits_length - 1)) - 1 + self.group_size = group_size + self.loss_function = loss_function + + def search(self, layer_name, sampled_input, act_abs_max, weight): + act = sampled_input + act.stop_gradient = True + print('[awq search] search input of %s' % layer_name) + dtype = weight.dtype + origin_out = paddle.matmul(act, weight) + best_error = float('inf') + best_ratio = -1 + best_scales = None + + for ratio in range(self.n_grid): + ratio = ratio * 1 / self.n_grid + act_abs_max_tmp = act_abs_max.detach().clone().cast('float32') + scales = paddle.clip(paddle.pow(act_abs_max_tmp, ratio), min=1e-4) + scales = scales / (scales.max() * scales.min()).sqrt() + scales = scales.cast(dtype) + new_weight = weight * scales.reshape([-1, 1]) + new_act = act / scales + quant_scale = compute_scales( + new_weight, method=self.weight_quant_method, group_size=self.group_size) + if self.weight_quant_method == 'groupwise': + quant_scale = paddle.repeat_interleave(quant_scale.cast('float32'), self.group_size, 0).cast(dtype) + quant_weight = paddle.clip( + paddle.round(new_weight / quant_scale * self.bnt), + -self.bnt - 1, self.bnt) + quant_dequant_weight = quant_weight / self.bnt * quant_scale + new_out = paddle.matmul(new_act, + quant_dequant_weight) + loss = self.loss_function(origin_out, new_out).numpy() + is_best = loss < best_error + if is_best: + print('find better ratio: {}, loss: {}'.format(ratio, loss)) + best_error = loss + best_ratio = ratio + best_scales = scales + + if best_scales is None: + best_scales = paddle.ones(scales.shape, dtype=dtype) + print('Cannot find better ratio.') + else: + print('Best ratio :{}, minimal loss : {}.'.format(best_ratio, best_error)) + return best_scales diff --git a/paddleslim/quant/advanced/piecewise_search.py b/paddleslim/quant/advanced/piecewise_search.py index 55678409b..e326f2e55 100644 --- a/paddleslim/quant/advanced/piecewise_search.py +++ b/paddleslim/quant/advanced/piecewise_search.py @@ -31,6 +31,8 @@ def __init__(self, search_scale_max=5., weight_quant_method='abs_max_channel_wise', act_quant_method='abs_max', + use_clip=False, + search_clip=False, loss_function=mse_loss): ''' PieceWiseSearch provides to search k_piece, alpha and scale. 
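        (In the code below: k_piece is the number of k-means clusters taken over the
        activation abs-max, alpha is the exponent that trades activation range against
        weight range, and the smooth scale of each piece is
        s = clip(act_abs_max**alpha / weight_abs_max**(1 - alpha), min=1e-5).)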
@@ -58,31 +60,36 @@ def __init__(self, self.act_quant_method = act_quant_method self.bnt = (1 << (bits_length - 1)) - 1 self.loss_function = loss_function + self.use_clip = use_clip + self.search_clip = search_clip def search(self, layer_name, sampled_input, act_abs_max, weight): act = sampled_input act.stop_gradient = True print('[smooth search] search input of %s' % layer_name) - + dtype = weight.dtype origin_out = paddle.matmul(act, weight) w_abs_max = weight.abs().max(axis=-1, keepdim=True) rw_abs_max = w_abs_max.reshape(act_abs_max.shape) - np_act_abs_max = np.array(act_abs_max) - np_rw_abs_max = np.array(rw_abs_max) - + smooth_scale_out = None global_loss = float('inf') best_scale = None - for k_piece in range(1, self.k_piece + 1): + if self.search_clip: + piece_range = [1] + list(range(1, self.k_piece + 1)) + else: + piece_range = list(range(1, self.k_piece + 1)) + + for k_idx, k_piece in enumerate(piece_range): if not self.search_piece: k_piece = self.k_piece print('Search {} Piece'.format(k_piece)) centroids, labels = k_means(act_abs_max, k_piece) piece = ['piece_{}'.format(a) for a in range(len(centroids))] for i in range(len(centroids)): - # print('search for piece {}; centroids value is {}'.format( - # piece[i], centroids[centroids.argsort()[i]].numpy())) + print('search for piece {}; centroids value is {}'.format( + piece[i], float(centroids[centroids.argsort()[i: i + 1]].cast('float32')))) alpha = self.search_alpha_min alpha_max = self.search_scale_max if self.search_scale_max is not None else self.search_alpha_max calibration_loss = float('inf') @@ -104,12 +111,16 @@ def search(self, layer_name, sampled_input, act_abs_max, weight): alpha = round(alpha, 2) if alpha < 1: - s = (np.power(np_act_abs_max, alpha) / np.power( - np_rw_abs_max, 1. 
- alpha)).clip(min=1e-5) - s = paddle.to_tensor(s, dtype='float32') + act_abs_max_tmp = act_abs_max.detach().clone() + s = paddle.clip(paddle.pow(act_abs_max_tmp, alpha) / paddle.pow( + rw_abs_max, 1 - alpha), min=1e-5) + + if self.use_clip or (k_piece == 1 and k_idx == 1 and self.search_clip): + s = paddle.clip(act_abs_max_tmp / paddle.max(act_abs_max / s), min=1) + del act_abs_max_tmp smooth_scale = s * mask_for_search else: - smooth_scale = alpha * mask_for_search + smooth_scale = paddle.to_tensor(alpha, dtype=dtype) * mask_for_search if smooth_scale_out is not None: mask_for_ones_new = paddle.where( @@ -145,9 +156,10 @@ def search(self, layer_name, sampled_input, act_abs_max, weight): calibration_loss = cur_loss final_smooth_scale = smooth_scale final_alpha = alpha + # print('Better alpha: {} loss: {}'.format(alpha, calibration_loss.cast('float32'))) - # print("Layer {} Piece {}, loss: {}, alpha : {}".format( - # layer_name, piece[i], float(calibration_loss), final_alpha)) + print("Layer {} Piece {}, loss: {}, alpha : {}".format( + layer_name, piece[i], float(calibration_loss.cast('float32')), final_alpha)) if smooth_scale_out is None: smooth_scale_out = final_smooth_scale else: @@ -160,4 +172,5 @@ def search(self, layer_name, sampled_input, act_abs_max, weight): print('Find Better K-Piece {}'.format(k_piece)) if not self.search_piece: break + return best_scale diff --git a/paddleslim/quant/advanced/smooth.py b/paddleslim/quant/advanced/smooth.py index e715788ed..5e32435f5 100644 --- a/paddleslim/quant/advanced/smooth.py +++ b/paddleslim/quant/advanced/smooth.py @@ -26,6 +26,8 @@ def __init__( model_config, alpha=0.5, smooth_all_linears=False, + start_sample_step=10000, + smooth_method='smoothquant', sample_function=None, search_function=None, ): ''' @@ -68,6 +70,8 @@ def __init__( self.smooth_all_linears = smooth_all_linears self.sample_function = sample_function self.search_function = search_function + self.start_sample_step = start_sample_step + self.smooth_method = smooth_method self.model.eval() self.step = 0 @@ -98,7 +102,6 @@ def _get_smooth_layers(self): self.ln_linear_dict, self.linear_ln_dict = get_ln_linear_info( self.layer_order, self.norm_flag, self.linear_flag, self.fused_qkv, self.parallel_ffn, self.skip_norm_list) - assert len(self.ln_linear_dict) > 0, 'No LN/Linear pair found' for key in self.ln_linear_dict: print('smooth pair LN {} : Linear {}'.format( @@ -147,29 +150,32 @@ def _forward_pre_hook(self, layer, input): def _sample_scale(self, input, ln_name): x = input[0] if type(input) == tuple else input x.stop_gradient = True - x_abs_max = x.abs().max(axis=1, keepdim=True) - x_abs_max = x_abs_max.max(axis=0) + + if self.smooth_method == 'smoothquant': + x_abs_max = x.abs().max(axis=1, keepdim=True) + x_abs_max = x_abs_max.max(axis=0) + elif self.smooth_method == 'awq': + x_abs_max = x.abs().reshape([-1, x.shape[-1]]) + x_abs_max = x_abs_max.mean(axis=0).reshape([1, -1]) + else: + raise NotImplementedError("To be implemented") if ln_name not in self.scale_dict: self.sampled_inputs[ln_name] = x self.scale_dict[ln_name] = x_abs_max else: - if self.sample_function is not None: + if self.sample_function is not None and self.step >= self.start_sample_step: self.sampled_inputs[ln_name] = self.sample_function.sample( x, self.sampled_inputs[ln_name], ln_name) else: self.sampled_inputs[ln_name] = x - tmp1 = paddle.concat([x_abs_max, self.scale_dict[ln_name]], axis=0) - self.scale_dict[ln_name] = tmp1.max(axis=0, keepdim=True) + if self.smooth_method == 'smoothquant': + tmp1 = 
paddle.concat([x_abs_max, self.scale_dict[ln_name]], axis=0) + self.scale_dict[ln_name] = tmp1.max(axis=0, keepdim=True) + elif self.smooth_method == 'awq': + tmp1 = paddle.concat([x_abs_max, self.scale_dict[ln_name]], axis=0) + self.scale_dict[ln_name] = tmp1.mean(axis=0, keepdim=True) - # per step print once - if self.print_step == self.step: - print('[Smooth] Step [{}]: {}. abs_min: {}, abs_max: {}'.format( - self.step, ln_name, - float(self.scale_dict[ln_name].cast("float32").min()), - float(self.scale_dict[ln_name].cast("float32").max()))) - if ln_name == list(self.linear_ln_dict.values())[-1]: - self.print_step += 1 def update_weight(self): @@ -181,24 +187,20 @@ def update_weight(self): if type(sub_layer) == ShiftSmoothHelpLayer: ln_name = layer_name if ln_name is not None: - act_abs_max = self.scale_dict[ln_name].cast("float32") - sampled_input = self.sampled_inputs[ln_name].cast("float32") + act_abs_max = self.scale_dict[ln_name].cast("float16") + sampled_input = self.sampled_inputs[ln_name].cast("float16") for param in sub_layer.parameters(include_sublayers=False): if 'w_0' in param.name: - weight = param.cast("float32") + # weight = param.cast("float32") if self.search_function is not None: s = self.search_function.search( - layer_name, sampled_input, act_abs_max, weight) + layer_name, sampled_input, act_abs_max, param.cast("float16")) else: - w_abs_max = weight.abs().max(axis=-1, keepdim=True) + w_abs_max = param.abs().max(axis=-1, keepdim=True) rw_abs_max = w_abs_max.reshape(act_abs_max.shape) - act_abs_max_np = act_abs_max.numpy() - weight_abs_max_np = rw_abs_max.numpy() - s = ( - np.power(act_abs_max_np, self.alpha) / np.power( - weight_abs_max_np, 1 - self.alpha)).clip( - min=1e-5) - s = paddle.to_tensor(s, dtype="float32") + act_abs_max_tmp = act_abs_max.detach().clone() + s = paddle.clip(paddle.pow(act_abs_max_tmp, self.alpha) / paddle.pow( + rw_abs_max, 1 - self.alpha), min=1e-5) self.smooth_scale_dict[ln_name] = s.cast(param.dtype) break @@ -273,4 +275,4 @@ def update_weight(self): def _remove_hook(self): for hook in self._forward_hook_list: hook.remove() - self._forward_hook_list = [] + self._forward_hook_list = [] \ No newline at end of file diff --git a/paddleslim/quant/advanced/utils.py b/paddleslim/quant/advanced/utils.py index 703fc5e1c..ff77462b2 100644 --- a/paddleslim/quant/advanced/utils.py +++ b/paddleslim/quant/advanced/utils.py @@ -38,7 +38,7 @@ def k_means(weight, n_clusters, init='k-means++', max_iter=300): return paddle.to_tensor(centroids.flatten()), paddle.to_tensor(labels) -def compute_scales(x, method='abs_max'): +def compute_scales(x, method='abs_max', group_size=-1): if method == 'abs_max': quant_scale = float(paddle.max(paddle.abs(x.flatten()))) quant_scale = 1e-8 if quant_scale == 0.0 else quant_scale @@ -52,8 +52,26 @@ def compute_scales(x, method='abs_max'): 0, dtype=x.dtype), paddle.to_tensor(1e-8, dtype=x.dtype), quant_scale) + elif method == 'groupwise': + input_shape = x.shape + input_processed = x.transpose([1, 0]).reshape( + [input_shape[1], input_shape[0] // group_size, group_size]) + quant_scale = paddle.max( + paddle.abs(input_processed), axis=2) + quant_scale = paddle.where(quant_scale == paddle.to_tensor(0, dtype=x.dtype), + paddle.to_tensor(1e-8, dtype=x.dtype), quant_scale) + quant_scale = quant_scale.transpose([1, 0]) + return quant_scale +def fake_quant(x, method='abs_max', weight_bits=8, group_size=-1): + bnt = (1 << (weight_bits - 1)) - 1 + quant_scale = compute_scales(x, method=method, group_size=group_size) + quant_value = 
paddle.clip( + paddle.round(x / quant_scale * bnt), -bnt - 1, bnt) + quant_dequant_value = quant_value / bnt * quant_scale + return quant_dequant_value + def find_parent_layer_and_sub_name(model, name): last_idx = 0 diff --git a/paddleslim/quant/observers/__init__.py b/paddleslim/quant/observers/__init__.py index 7ab3b723e..0b7970ba8 100644 --- a/paddleslim/quant/observers/__init__.py +++ b/paddleslim/quant/observers/__init__.py @@ -20,6 +20,7 @@ from .abs_max import AbsmaxObserver from .mse_weight import MSEChannelWiseWeightObserver from .abs_max_weight import AbsMaxChannelWiseWeightObserver +from .groupwise import GroupWiseWeightObserver __all__ = [ "HistObserver", @@ -31,4 +32,5 @@ "AbsmaxObserver", "MSEChannelWiseWeightObserver", "AbsMaxChannelWiseWeightObserver", + "GroupWiseWeightObserver" ] diff --git a/paddleslim/quant/observers/groupwise.py b/paddleslim/quant/observers/groupwise.py new file mode 100644 index 000000000..1db2067c6 --- /dev/null +++ b/paddleslim/quant/observers/groupwise.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from .channel_wise import ChannelWiseObserver +from paddle.quantization.factory import ObserverFactory + + +class GroupWiseWeightObserver(ObserverFactory): + r""" + It collects channel-wise maximum absolute values of target weights. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. 
code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import AbsMaxChannelWiseWeightObserver + quanter = AbsMaxChannelWiseWeightObserver() + q_config = QuantConfig(activation=None, weight=quanter) + """ + + def __init__(self, quant_bits=8, group_size=128): + super(GroupWiseWeightObserver, self).__init__( + quant_bits=quant_bits, + group_size=group_size) + + def _get_class(self): + return GroupWiseWeightObserverLayer + + +class GroupWiseWeightObserverLayer(ChannelWiseObserver): + def __init__(self, layer, quant_bits=8, group_size=128): + super(GroupWiseWeightObserverLayer, self).__init__( + layer, + quant_bits=quant_bits, + sign=True, + symmetric=True, ) + self.quant_bits = quant_bits + self.group_size = group_size + self.qmin, self.qmax = self.qmin_qmax + self._layer = layer + self._max = None + self._scale = None + self._zero_point = None + + def forward(self, inputs): + self._max = self._cal_abs_max(inputs) + return inputs + + def _cal_abs_max(self, inputs): + """ Use group_size to group the input, then use the + absmax method to calculate the scale + """ + input_shape = inputs.shape + assert self.group_size == 64 or self.group_size == 128, \ + "group_size only support 64 or 128" + assert inputs.shape[0] % self.group_size == 0, \ + "group_size must be a factor of input channels" + assert len(inputs.shape) == 2, \ + "Currently only support 2D tensor" + input_processed = inputs.transpose([1, 0]).reshape( + [input_shape[1], input_shape[0] // self.group_size, self.group_size]) + + abs_max_values = paddle.max( + paddle.abs(input_processed), axis=2).cast("float32") + # "abs_max_values < 1e-8" in bfloat16 type? + abs_max_values = paddle.where(abs_max_values == np.float32(0), + np.float32(1e-8), abs_max_values) + abs_max_values = abs_max_values.transpose([1, 0]) + return abs_max_values + + def min_value(self) -> float: + return 0. + + def max_value(self) -> float: + return self._max + + def cal_thresholds(self): + """ Compute thresholds for MAX function. + """ + if self._scale is None: + self._scale = self._max + self._zero_point = paddle.zeros_like(self._scale) + + def scales(self): + """ Return output scales. + """ + if self._scale is None: + self.cal_thresholds() + return self._scale + + def zero_points(self): + """ Return output zero points. 
+ """ + if self._zero_point is None: + self.cal_thresholds() + return self._zero_point From 521157e390aa8bca62953e251495257e334a9477 Mon Sep 17 00:00:00 2001 From: Chang Xu Date: Thu, 28 Dec 2023 21:36:33 +0800 Subject: [PATCH 3/5] [Cherry-Pick]Cp fit paddle26 (#1823) --- paddleslim/quant/advanced/gptq.py | 19 +++++++++++++------ paddleslim/quant/advanced/piecewise_search.py | 3 +++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/paddleslim/quant/advanced/gptq.py b/paddleslim/quant/advanced/gptq.py index 96566858f..5ae47205c 100644 --- a/paddleslim/quant/advanced/gptq.py +++ b/paddleslim/quant/advanced/gptq.py @@ -106,8 +106,9 @@ def fasterquant(self, H = self.hessian del self.hessian dead = paddle.where(paddle.diag(H) == 0) - H[dead, dead] = 1 - W[:, dead] = 0 + if dead[0].shape[0] != 0: + H[dead, dead] = 1 + W[:, dead] = 0 del dead if actorder: perm = paddle.argsort(paddle.diag(H), descending=True) @@ -122,9 +123,15 @@ def fasterquant(self, damp = percdamp * paddle.mean(paddle.diag(H)) diag = paddle.arange(self.columns) H[diag, diag] += damp - - H = paddle.inverse(H) - H = paddle.linalg.cholesky(H, upper=True) + try: + H = paddle.inverse(H) + H = paddle.linalg.cholesky(H, upper=True) + except: + print('We skip GPTQ this layer now.') + print( + 'If you want GPTQ this layer, please try setting damp_percent larger or increasing the number of samples.' + ) + return Hinv = H for i1 in range(0, self.columns, blocksize): @@ -182,4 +189,4 @@ def fasterquant(self, self.quantized = True del H, Q, Hinv, W, Losses - paddle.device.cuda.empty_cache() + paddle.device.cuda.empty_cache() \ No newline at end of file diff --git a/paddleslim/quant/advanced/piecewise_search.py b/paddleslim/quant/advanced/piecewise_search.py index e326f2e55..a95b2a1c7 100644 --- a/paddleslim/quant/advanced/piecewise_search.py +++ b/paddleslim/quant/advanced/piecewise_search.py @@ -97,6 +97,8 @@ def search(self, layer_name, sampled_input, act_abs_max, weight): mask_for_search = paddle.where(labels == centroids.argsort()[i], 1., 0.) mask_for_ones = paddle.where(mask_for_search == 0., 1., 0.) + mask_for_search = mask_for_search.cast(dtype) + mask_for_ones = mask_for_ones.cast(dtype) while alpha <= alpha_max: if alpha < 1: @@ -125,6 +127,7 @@ def search(self, layer_name, sampled_input, act_abs_max, weight): if smooth_scale_out is not None: mask_for_ones_new = paddle.where( smooth_scale_out == 0., 1., 0.) + mask_for_ones_new = mask_for_ones_new.cast(dtype) mask_for_ones *= mask_for_ones_new smooth_scale_ = smooth_scale_out + smooth_scale smooth_scale_tmp = smooth_scale_ + mask_for_ones From e72073a24ebc9d0929e2b100b92385e161c0b484 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Fri, 2 Feb 2024 03:25:39 +0000 Subject: [PATCH 4/5] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=9B=AE=E6=A0=87?= =?UTF-8?q?=E6=A3=80=E6=B5=8B=E6=A8=A1=E5=9E=8B=E7=A6=BB=E7=BA=BF=E9=87=8F?= =?UTF-8?q?=E5=8C=96=E7=A4=BA=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../detection/README.md | 113 ++++++++++++++++-- .../detection/configs/picodet_s_ptq.yaml | 4 +- .../detection/configs/ppyoloe_s_ptq.yaml | 5 +- 3 files changed, 106 insertions(+), 16 deletions(-) diff --git a/example/post_training_quantization/detection/README.md b/example/post_training_quantization/detection/README.md index f590606dd..41e217a15 100644 --- a/example/post_training_quantization/detection/README.md +++ b/example/post_training_quantization/detection/README.md @@ -22,9 +22,6 @@ | 模型 | 策略 | 输入尺寸 | mAPval
0.5:0.95 | 预测时延FP32<br>(ms) |预测时延FP16<br>(ms) | 预测时延INT8<br>
(ms) | 配置文件 | Inference模型 | | :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | -| PP-YOLOE-s | Base模型 | 640*640 | 43.1 | 11.2ms | 7.7ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_crn_s_300e_coco.tar) | -| PP-YOLOE-s | 离线量化 | 640*640 | 42.6 | - | - | 6.7ms | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/ppyoloe_s_ptq.tar) | -| | | | | | | | | | | PicoDet-s | Base模型 | 416*416 | 32.5 | - | - | - | - | [Model](https://paddledet.bj.bcebos.com/deploy/Inference/picodet_s_416_coco_lcnet.tar) | | PicoDet-s | 离线量化(量化分析前) | 416*416 | 0.0 | - | - | - | - | - | | PicoDet-s | 离线量化(量化分析后) | 416*416 | 24.9 | - | - | - | - | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_ptq.tar) | @@ -35,22 +32,24 @@ ## 3. 离线量化流程 #### 3.1 准备环境 -- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) -- PaddleSlim >= 2.3 +- PaddlePaddle 2.5 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) +- PaddleSlim 2.5 - PaddleDet >= 2.4 - opencv-python 安装paddlepaddle: ```shell # CPU -pip install paddlepaddle +python -m pip install paddlepaddle==2.5.0 -i https://pypi.tuna.tsinghua.edu.cn/simple # GPU -pip install paddlepaddle-gpu +python -m pip install paddlepaddle-gpu==2.5.0.post116 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html ``` 安装paddleslim: +注意,需要修改setup.py中slim_version='2.5',否则会安装最新版本的PaddleSlim。 ```shell -pip install paddleslim +git clone -b release/2.5 https://github.com/PaddlePaddle/PaddleSlim.git & cd PaddleSlim +python setup.py install ``` 安装paddledet: @@ -66,7 +65,7 @@ pip install paddledet 如果数据集为非COCO格式数据,请修改[configs](./configs)中reader配置文件中的Dataset字段。 -以PP-YOLOE模型为例,如果已经准备好数据集,请直接修改[./configs/ppyoloe_s_ptq.yml]中`EvalDataset`的`dataset_dir`字段为自己数据集路径即可。 +以PP-YOLOE模型为例,如果已经准备好数据集,请直接修改[./configs/ppyoloe_s_ptq.yml]中`EvalDataset`和`TrainDataset`的`dataset_dir`字段为自己数据集路径即可。 #### 3.3 准备预测模型 @@ -125,12 +124,17 @@ python post_quant.py --config_path=./configs/picodet_s_ptq.yaml --save_dir=./pic export CUDA_VISIBLE_DEVICES=0 python eval.py --config_path=./configs/ppyoloe_s_ptq.yaml ``` +这个,测试不出来模型精度,因为[ppyoloe_s_ptq.yaml](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/example/post_training_quantization/detection/configs/ppyoloe_s_ptq.yaml)的model_dir是没有NMS的,所以不打印精度 +``` +export CUDA_VISIBLE_DEVICES=0 +python eval.py --config_path=./configs/picodet_s_ptq.yaml +``` **注意**: - 要测试的模型路径可以在配置文件中`model_dir`字段下进行修改。 #### 3.6 提高离线量化精度 -本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisPTQ```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。```paddleslim.quant.AnalysisPTQ```详解见[AnalysisPTQ.md](../../../docs/zh_cn/tutorials/quant/AnalysisPTQ.md)。 +本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisPTQ```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。```paddleslim.quant.AnalysisPTQ```详解见[离线量化](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/tutorials/quant/post_training_quantization.md)。 经过多个实验,包括尝试多种激活算法(avg,KL等)、weight的量化方式(abs_max,channel_wise_abs_max),对PicoDet-s进行离线量化后精度均为0,以PicoDet-s为例,量化分析工具具体使用方法如下: @@ -167,10 +171,97 @@ python post_quant.py --config_path=./configs/picodet_s_analyzed_ptq.yaml --save_ 
注:分析之后若需要直接产出符合目标精度的量化模型,demo代码不会使用少量数据集验证,会自动使用全量验证数据。 - ## 4.预测部署 预测部署可参考[Detection模型自动压缩示例](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/auto_compression/detection) +- TensorRT预测 +- 把[picodet_reader.yml](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/example/auto_compression/detection/configs/picodet_reader.yml)中的dataset_dir改成你环境下的数据集路径 + +```shell +python paddle_inference_eval.py \ + --model_path=picodet_s_416_coco_lcnet \ + --reader_config=configs/picodet_reader.yml \ + --use_trt=True \ + --precision=fp16 \ + --include_nms=True \ + --benchmark=True +``` +量化分析前: +```shell +python paddle_inference_eval.py \ + --model_path=picodet_s_ptq \ + --reader_config=configs/picodet_reader.yml \ + --use_trt=True \ + --precision= \ + --include_nms=True \ + --benchmark=True +``` +量化分析后: +```shell +python paddle_inference_eval.py \ + --model_path=picodet_s_analyzed_ptq_out \ + --reader_config=configs/picodet_reader.yml \ + --use_trt=True \ + --precision=int8 \ + --include_nms=True \ + --benchmark=True +``` +#### 4.1 C++部署 +请参考[YOLOv3推理](https://github.com/PaddlePaddle/Paddle-Inference-Demo/tree/master/c%2B%2B/gpu/yolov3) +编译样例 +- 文件yolov3_test.cc改成PicoDet-s.cc,为预测的样例程序(程序中的输入为固定值,如果您有opencv或其他方式进行数据读取的需求,需要对程序进行一定的修改)。 +- 脚本compile.sh包含了第三方库、预编译库的信息配置。 +- 脚本run.sh为一键运行脚本。 +编译前,需要根据自己的环境修改compile.sh中的相关代码配置依赖库: +```shell +# 编译的 demo 名称 +DEMO_NAME=picoDet-s + +# 根据预编译库中的version.txt信息判断是否将以下三个标记打开 +WITH_MKL=ON +WITH_GPU=ON +USE_TENSORRT=ON + +# 配置预测库的根目录 +LIB_DIR=${work_path}/../lib/paddle_inference + +# 如果上述的WITH_GPU 或 USE_TENSORRT设为ON,请设置对应的CUDA, CUDNN, TENSORRT的路径。 +CUDNN_LIB=/usr/lib/x86_64-linux-gnu/ +CUDA_LIB=/usr/local/cuda/lib64 +TENSORRT_ROOT=/usr/local/TensorRT-7.1.3.4 +``` +运行bash compile.sh编译样例 +运行样例: +- 使用原生GPU运行样例 +```shell +./build/picodet-s --model_file picodet_s_416_coco_lenet/model.pdmodel --params_file picodet_s_416_coco_lenet/model.pdiparams +``` +- 使用Trt FP32运行样例 +```shell +./build/picodet-s --model_file picodet_s_416_coco_lenet/model.pdmodel --params_file picodet_s_416_coco_lenet/model.pdiparams --run_mode=trt_fp32 +``` + +- 使用Trt FP16运行样例 +```shell +./build/picodet-s --model_file picodet_s_416_coco_lenet/model.pdmodel --params_file picodet_s_416_coco_lenet/model.pdiparams --run_mode=trt_fp16 +``` +- 使用Trt INT8运行样例 +在使用Trt Int8运行样例时,相同的运行命令需要执行两次。 +生成量化校准表 +```shell +./build/picodet-s --model_file picodet_s_416_coco_lcnet/model.pdmodel --params_file picodet_s_416_coco_lcnet/model.pdiparams --run_mode=trt_int8 +``` +执行后,模型文件夹Picodet下的_opt_cache文件夹下会多出一个名字为trt_calib_*的文件,即校准表。 +加载校准表执行预测 +```shell +./build/picodet-s --model_file picodet_s_416_coco_lcnet/model.pdmodel --params_file picodet_s_416_coco_lcnet/model.pdiparams --run_mode=trt_int8 +``` +- 使用Trt dynamic shape运行样例(以FP32为例) +```shell +./build/picodet-s --model_file picodet_s_416_coco_lcnet/model.pdmodel --params_file picodet_s_416_coco_lcnet/model.pdiparams --run_mode=trt_fp32 --use_dynamic_shape=1 +``` + + ## 5.FAQ - 如果想对模型进行自动压缩,可进入[Detection模型自动压缩示例](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/auto_compression/detection)中进行实验。 diff --git a/example/post_training_quantization/detection/configs/picodet_s_ptq.yaml b/example/post_training_quantization/detection/configs/picodet_s_ptq.yaml index 005c0d46c..1429a9d88 100644 --- a/example/post_training_quantization/detection/configs/picodet_s_ptq.yaml +++ b/example/post_training_quantization/detection/configs/picodet_s_ptq.yaml @@ -7,6 +7,7 @@ skip_tensor_list: None metric: COCO num_classes: 80 + # Datset configuration TrainDataset: 
!COCODataSet @@ -34,5 +35,4 @@ EvalReader: - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False} - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} - batch_size: 32 - + batch_size: 16 \ No newline at end of file diff --git a/example/post_training_quantization/detection/configs/ppyoloe_s_ptq.yaml b/example/post_training_quantization/detection/configs/ppyoloe_s_ptq.yaml index 3c8752652..d1f5340e0 100644 --- a/example/post_training_quantization/detection/configs/ppyoloe_s_ptq.yaml +++ b/example/post_training_quantization/detection/configs/ppyoloe_s_ptq.yaml @@ -1,4 +1,4 @@ -input_list: ['image'] +input_list: ['image', 'scale_factor'] arch: PPYOLOE # When export exclude_nms=True, need set arch: PPYOLOE model_dir: ./ppyoloe_crn_s_300e_coco model_filename: model.pdmodel @@ -6,7 +6,6 @@ params_filename: model.pdiparams metric: COCO num_classes: 80 - # Datset configuration TrainDataset: !COCODataSet @@ -29,4 +28,4 @@ EvalReader: - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} - Permute: {} - batch_size: 32 \ No newline at end of file + batch_size: 16 \ No newline at end of file From 6bbc9c5b3785af2f5ffd2e2e3beaae82d3495248 Mon Sep 17 00:00:00 2001 From: lizexu <2694294196@qq.com> Date: Fri, 23 Feb 2024 08:19:56 +0000 Subject: [PATCH 5/5] =?UTF-8?q?YOLO=E7=B3=BB=E5=88=97=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=87=AA=E5=8A=A8=E5=8E=8B=E7=BC=A9=E7=A4=BA=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pytorch_yolo_series/README.md | 159 ++++++--- .../cpp_infer/CMakeLists.txt | 315 ++++++++++++------ .../pytorch_yolo_series/cpp_infer/README.md | 49 +-- .../pytorch_yolo_series/cpp_infer/compile.sh | 4 +- .../pytorch_yolo_series/cpp_infer/trt_run.cc | 145 ++++---- .../paddle_inference_eval.py | 18 +- .../pytorch_yolo_series/README.md | 37 +- 7 files changed, 441 insertions(+), 286 deletions(-) diff --git a/example/auto_compression/pytorch_yolo_series/README.md b/example/auto_compression/pytorch_yolo_series/README.md index 87841ddb6..1b580977d 100644 --- a/example/auto_compression/pytorch_yolo_series/README.md +++ b/example/auto_compression/pytorch_yolo_series/README.md @@ -19,47 +19,50 @@ | 模型 | 策略 | 输入尺寸 | mAPval
0.5:0.95 | 模型体积 | 预测时延FP32 |预测时延FP16 | 预测时延INT8
| 内存占用 | 显存占用 | 配置文件 | Inference模型 | |:--------------|:-------- |:--------: |:-----------------------:|:------:| :----------------: | :----------------: |:----------------: | :----------------: | :---------------: |:------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| -| YOLOv5s | Base模型 | 640*640 | 37.4 | 28.1MB | 6.87ms | 3.51ms | - | 1718MB | 705MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) | -| YOLOv5s | 离线量化 | 640*640 | 36.0 | 7.4MB | - | - | 3.17ms | 736MB | 315MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | -| YOLOv5s | ACT量化训练 | 640*640 | **36.9** | 7.4MB | - | - | **3.17ms** | 736MB | 315MB | [config](./configs/yolov5s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant_onnx.tar) | +| YOLOv5s | Base模型 | 640*640 | 37.5 | 28.1MB | 14.4ms | 6.9ms | - | 2637MB | 1143MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx) | +| YOLOv5s | avg离线量化 | 640*640 | 36.7 | 7.5MB | - | - | 6.4ms | 2669MB | 1089MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | +| YOLOv5s | 量化蒸馏训练 | 640*640 | **36.8** | 7.5MB | - | - | **6.8ms** | 2593MB | 1083MB | [config](./configs/yolov5s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant_onnx.tar) | | | | | | | | | | | -| YOLOv6s | Base模型 | 640*640 | 42.4 | 65.9MB | 9.18ms | 3.58ms | - | 1208MB | 555MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) | -| YOLOv6s | KL离线量化 | 640*640 | 30.3 | 16.8MB | - | - | 2.81ms | 736MB | 315MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | -| YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | 16.8MB | - | - | **2.81ms** | 736MB | 315MB | [config](./configs/yolov6s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant_onnx.tar) | +| YOLOv6s | Base模型 | 640*640 | 42.5 | 66MB | 18.13ms | 7.1ms | - | 2660MB | 1183MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx) | +| YOLOv6s | KL离线量化 | 640*640 | 34,0 | 17MB | - | - | 4.9ms | 2570MB | 1085MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | +| YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | 17MB | - | - | **4.9ms** | 2532MB | 1085MB | [config](./configs/yolov6s_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant_onnx.tar) | | | | | | | | | | | -| YOLOv6s_v2 | Base模型 | 640*640 | 43.4 | 67.4MB | 9.18ms | 3.58ms | - | 1208MB | 555MB | - | [Model](https://github.com/meituan/YOLOv6/releases/download/0.2.0/yolov6s.onnx) | -| YOLOv6s_v2 | 量化蒸馏训练 | 640*640 | **43.0** | 16.8MB | - | - | **2.81ms** | 736MB | 315MB | [config](./configs/yolov6s_v2_qat_dis.yaml) | 
[Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_v2_0_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_v2_0_quant_onnx.tar) | -| | | | | | | | | | -| YOLOv7 | Base模型 | 640*640 | 51.1 | 141MB | 26.76ms | 8.16ms | - | 1722MB | 917MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7.onnx) | -| YOLOv7 | 离线量化 | 640*640 | 50.2 | 36MB | - | - | 5.19ms | 827MB | 363MB | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/example/post_training_quantization/pytorch_yolo_series) | - | -| YOLOv7 | ACT量化训练 | 640*640 | **50.9** | 36MB | - | - | **5.19ms** | 827MB | 363MB | [config](./configs/yolov7_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_quant_onnx.tar) | -| | | | | | | | | | -| YOLOv7-Tiny | Base模型 | 640*640 | 37.3 | 24MB | 5.06ms | 2.32ms | - | 738MB | 349MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx) | -| YOLOv7-Tiny | 离线量化 | 640*640 | 35.8 | 6.1MB | - | - | 1.68ms | 729MB | 315MB | - | - | -| YOLOv7-Tiny | ACT量化训练 | 640*640 | **37.0** | 6.1MB | - | - | **1.68ms** | 729MB | 315MB | [config](./configs/yolov7_tiny_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant_onnx.tar) | +| YOLOv7-Tiny | Base模型 | 640*640 | 37.2 | 24MB | 13.2ms | 8.1ms | - | 2466MB | 1133MB | - | [Model](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx) | +| YOLOv7-Tiny | 量化蒸馏训练 | 640*640 | **36.8** | 6.2MB | - | - | **6.6ms** | 2547MB | 1085MB | [config](./configs/yolov7_tiny_qat_dis.yaml) | [Infer Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant.tar) | [ONNX Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov7_tiny_quant_onnx.tar) | 说明: -- mAP的指标均在COCO val2017数据集中评测得到。 -- YOLOv7模型在Tesla T4的GPU环境下开启TensorRT 8.4.1,batch_size=1, 测试脚本是[cpp_infer](./cpp_infer)。 +- mAP的指标均在COCO val2017数据集中评测得到,IoU=0.5:0.95。 +- 测速环境:Tesla T4,TensorRT 8.6.1,CUDA 11.2,batch_size=1,cudnn 8.2.0 Intel(R)Xeon(R)Gold 6271C CPU ## 3. 
自动压缩流程 ### 3.1 准备环境 -- PaddlePaddle >= 2.4版本 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)根据相应环境的安装指令进行安装) -- PaddleSlim >= 2.4版本 +- PaddlePaddle 2.6 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)根据相应环境的安装指令进行安装) +- PaddleSlim 2.6 +- PaddleDet >=2.4 (1)安装paddlepaddle ```shell # CPU -pip install paddlepaddle==2.4.1 +python -m pip install paddlepaddle==2.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple # GPU 以Ubuntu、CUDA 11.2为例 -python -m pip install paddlepaddle-gpu==2.4.1.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html +python -m pip install paddlepaddle-gpu==2.6.0.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html ``` -(2)安装paddleslim>=2.4 +(2)安装paddleslim 2.6 ```shell pip install paddleslim ``` +(3) 安装paddledet +```shell +pip install paddledet +``` +注:安装PaddleDet的目的是为了直接使用PaddleDetection中的Dataloader组件。 + +(4)安装x2paddle +```shell +pip install X2Paddle==1.3.9 +``` #### 版本对齐 @@ -136,6 +139,12 @@ pip install paddleslim **注意**:目前ACT支持**不带NMS**模型,使用如上命令导出即可。也可以直接下载我们已经准备好的[yolov7.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov7-tiny.onnx)。 + 将ONNX模型转换为Paddle模型,举例: + 使用命令行将Yolov6s.onnx转换为Paddle模型 + ```shell + x2paddle --framework=onnx --model=yolov6s.onnx --save_dir=yolov6_model + ``` + ### 3.4 自动压缩并产出模型 蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。 @@ -145,13 +154,14 @@ pip install paddleslim - 单卡训练: ``` export CUDA_VISIBLE_DEVICES=0 -python run.py --config_path=./configs/yolov7_tiny_qat_dis.yaml --save_dir='./output/' +python run.py --config_path=./configs/yolov7_tiny_qat_dis.yaml --save_dir='./yolov7-quantAware/' ``` - 多卡训练: ``` -CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \ - --config_path=./configs/yolov7_tiny_qat_dis.yaml --save_dir='./output/' +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \ + --config_path=./configs/yolov6s_qat_dis.yaml --save_dir='./yolov6s_quantaware/' ``` @@ -177,42 +187,103 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log - | model_path | inference 模型文件所在目录,该目录下需要有文件 model.pdmodel 和 model.pdiparams 两个文件 | | dataset_dir | eval时数据验证集路径, 默认`dataset/coco` | | image_file | 如果只测试单张图片效果,直接根据image_file指定图片路径 | +| val_image_dir | coco数据集中验证图像的目录名,默认为val2017 | +| val_anno_path | 指定COCO数据集的注释(annotation)文件路径,这是包含验证集标注信息的JSON文件,默认为annotations/instances_val2017.json | +| benchmark | 指定是否运行性能基准测试。如果设置为True,程序将会进行性能测试 | | device | 使用GPU或者CPU预测,可选CPU/GPU | | use_trt | 是否使用 TesorRT 预测引擎 | | use_mkldnn | 是否启用```MKL-DNN```加速库,注意```use_mkldnn```与```use_gpu```同时为```True```时,将忽略```enable_mkldnn```,而使用```GPU```预测 | +| use_dynamic_shape | 是否使用动态形状(dynamic_shape)功能 | | cpu_threads | CPU预测时,使用CPU线程数量,默认10 | | precision | 预测精度,包括`fp32/fp16/int8` | +| arch | 指定所使用的模型架构的名称,例如YOLOv5 | +| img_shape | 指定模型输入的图像尺寸 | +| batch_size | 指定模型输入的批处理大小 | - TensorRT Python部署: + Paddle-TensorRT Python部署 首先安装带有TensorRT的[Paddle安装包](https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python)。 然后使用[paddle_inference_eval.py](./paddle_inference_eval.py)进行部署: +- YOLOv5: +```shell +python paddle_inference_eval.py \ + --model_path=yolov5_model/inference_model \ + --dataset_dir=/datasets/coco \ + --use_trt=True \ + --precision=fp32 \ + --arch=YOLOv5 +``` 
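All of the Paddle-TensorRT commands in this section go through `paddle_inference_eval.py`. For readers who want to embed the exported model directly, the snippet below is a minimal sketch of building a Paddle Inference predictor with TensorRT INT8 enabled; the model directory name (`yolov5s_quantaware`) and the 640x640 dummy input are illustrative assumptions, and the script's own flags remain the reference for full COCO evaluation.

```python
import numpy as np
import paddle.inference as paddle_infer

# Assumed path: the quantized YOLOv5s export produced by run.py.
config = paddle_infer.Config("yolov5s_quantaware/model.pdmodel",
                             "yolov5s_quantaware/model.pdiparams")
config.enable_use_gpu(256, 0)  # 256 MB initial GPU memory pool, device 0
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Int8,  # Float32/Half for the fp32/fp16 runs
    use_static=False,
    use_calib_mode=False)  # the quantized model already carries its scales

predictor = paddle_infer.create_predictor(config)

# Feed one 640x640 dummy image; real evaluation uses the COCO reader in the script.
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(np.zeros((1, 3, 640, 640), dtype="float32"))
predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
print(output_handle.copy_to_cpu().shape)
```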
+```shell +python paddle_inference_eval.py \ + --model_path=yolov5s_quantaware \ + --dataset_dir=/datasets/coco \ + --use_trt=True \ + --precision=int8 \ + --arch=YOLOv5 +``` +- YOLOv6: +```shell +python paddle_inference_eval.py \ + --model_path=yolov6_model/inference_model \ + --dataset_dir=/datasets/coco \ + --use_trt=True \ + --precision=fp32 \ + --arch=YOLOv6 +``` +```shell +python paddle_inference_eval.py \ + --model_path=yolov6s_quantaware \ + --dataset_dir=/datasets/coco \ + --use_trt=True \ + --precision=int8 \ + --arch=YOLOv6 +``` +- YOLOv7: ```shell python paddle_inference_eval.py \ - --model_path=output \ - --reader_config=configs/yoloe_reader.yml \ + --model_path=yolov7-tiny/inference_model \ + --dataset_dir=/datasets/coco \ --use_trt=True \ - --precision=int8 + --precision=fp32 \ + --arch=YOLOv7 ``` +```shell +python paddle_inference_eval.py \ + --model_path=yolov7-quantAware \ + --dataset_dir=/datasets/coco \ + --use_trt=True \ + --precision=int8 \ + --arch=YOLOv7 - MKLDNN预测: ```shell python paddle_inference_eval.py \ - --model_path=output \ - --reader_config=configs/yoloe_reader.yml \ + --model_path=yolov5_model/inference_model \ + --dataset_dir=/datasets/coco \ --device=CPU \ --use_mkldnn=True \ --cpu_threads=10 \ - --precision=int8 + --precision=fp32 \ + --arch=YOLOv5 +``` +- 原生GPU推理 + +```shell +python paddle_inference_eval.py \ + --model_path=yolov5_model/inference_model \ + --dataset_dir=/datasets/coco \ + --device=GPU \ + --precision=fp32 \ + --arch=YOLOv5 ``` - 测试单张图片 ```shell -python paddle_inference_eval.py --model_path=output --image_file=images/000000570688.jpg --use_trt=True --precision=int8 +python paddle_inference_eval.py --model_path=yolov5_model/inference_model --image_file=images/000000570688.jpg --use_trt=True --precision=fp32 --arch=YOLOv5 ``` - C++部署 @@ -222,25 +293,7 @@ python paddle_inference_eval.py --model_path=output --image_file=images/00000057 # 编译 bash compile.sh # 执行 -./build/trt_run --model_file yolov7_quant/model.pdmodel --params_file yolov7_quant/model.pdiparams --run_mode=trt_int8 -``` - -### 导出至ONNX使用TensorRT部署 - -加载`quant_model.onnx`和`calibration.cache`,可以直接使用TensorRT测试脚本进行验证,详细代码可参考[TensorRT部署](./TensorRT) - -- python测试: -```shell -cd TensorRT -python trt_eval.py --onnx_model_file=output/ONNX/quant_model.onnx \ - --calibration_file=output/ONNX/calibration.cache \ - --image_file=../images/000000570688.jpg \ - --precision_mode=int8 -``` - -- 速度测试 -```shell -trtexec --onnx=output/ONNX/quant_model.onnx --avgRuns=1000 --workspace=1024 --calib=output/ONNX/calibration.cache --int8 +./build/trt_run --model_file yolov7-quantAware/model.pdmodel --params_file yolov7-quantAware/model.pdiparams --run_mode=trt_int8 ``` ## 5.FAQ diff --git a/example/auto_compression/pytorch_yolo_series/cpp_infer/CMakeLists.txt b/example/auto_compression/pytorch_yolo_series/cpp_infer/CMakeLists.txt index d5307c657..ec317d6cf 100644 --- a/example/auto_compression/pytorch_yolo_series/cpp_infer/CMakeLists.txt +++ b/example/auto_compression/pytorch_yolo_series/cpp_infer/CMakeLists.txt @@ -1,43 +1,47 @@ cmake_minimum_required(VERSION 3.0) project(cpp_inference_demo CXX C) -option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON) -option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF) -option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) -option(USE_TENSORRT "Compile demo with TensorRT." OFF) +option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." 
ON) +option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF) +option(WITH_STATIC_LIB + "Compile demo with static/shared library, default don't use static." OFF) +option(USE_TENSORRT "Compile demo with TensorRT." OFF) +option(WITH_SHARED_PHI "Compile demo with phi shared lib" ON) option(WITH_ROCM "Compile demo with rocm." OFF) -option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF) +option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF) option(WITH_ARM "Compile demo with ARM" OFF) option(WITH_MIPS "Compile demo with MIPS" OFF) -option(WITH_SW "Compile demo with SW" OFF) -option(WITH_XPU "Compile demow ith xpu" OFF) -option(WITH_NPU "Compile demow ith npu" OFF) +option(WITH_LOONGARCH "Compile demo with LOONGARCH" OFF) +option(WITH_SW "Compile demo with SW" OFF) +option(WITH_XPU "Compile demo with xpu" OFF) +option(WITH_NPU "Compile demo with npu" OFF) if(NOT WITH_STATIC_LIB) add_definitions("-DPADDLE_WITH_SHARED_LIB") else() - # PD_INFER_DECL is mainly used to set the dllimport/dllexport attribute in dynamic library mode. + # PD_INFER_DECL is mainly used to set the dllimport/dllexport attribute in dynamic library mode. # Set it to empty in static library mode to avoid compilation issues. add_definitions("/DPD_INFER_DECL=") endif() macro(safe_set_static_flag) - foreach(flag_var - CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE - CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) - if(${flag_var} MATCHES "/MD") - string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") - endif(${flag_var} MATCHES "/MD") - endforeach(flag_var) + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif() + endforeach() endmacro() if(NOT DEFINED PADDLE_LIB) - message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib") + message( + FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib") endif() if(NOT DEFINED DEMO_NAME) message(FATAL_ERROR "please set DEMO_NAME with -DDEMO_NAME=demo_name") endif() -include_directories("${PADDLE_LIB}/") +include_directories("${PADDLE_LIB}/paddle/include") set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/") include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include") include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include") @@ -56,19 +60,23 @@ link_directories("${PADDLE_LIB}/paddle/lib") link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib") link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib") -if (WIN32) +if(WIN32) add_definitions("/DGOOGLE_GLOG_DLL_DECL=") option(MSVC_STATIC_CRT "use static C Runtime library by default" ON) - if (MSVC_STATIC_CRT) - if (WITH_MKL) + if(MSVC_STATIC_CRT) + if(WITH_MKL) set(FLAG_OPENMP "/openmp") endif() - set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}") - set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}") - set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}") - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}") + set(CMAKE_C_FLAGS_DEBUG + "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}") + set(CMAKE_C_FLAGS_RELEASE + "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}") + set(CMAKE_CXX_FLAGS_DEBUG + "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}") + set(CMAKE_CXX_FLAGS_RELEASE + 
"${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}") safe_set_static_flag() - if (WITH_STATIC_LIB) + if(WITH_STATIC_LIB) add_definitions(-DSTATIC_LIB) endif() endif() @@ -83,38 +91,50 @@ if(WITH_GPU) if(NOT WIN32) include_directories("/usr/local/cuda/include") if(CUDA_LIB STREQUAL "") - set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library") + set(CUDA_LIB + "/usr/local/cuda/lib64/" + CACHE STRING "CUDA Library") endif() else() - include_directories("C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\include") + include_directories( + "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\include") if(CUDA_LIB STREQUAL "") - set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64") + set(CUDA_LIB + "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64" + ) endif() - endif(NOT WIN32) + endif() endif() -if (USE_TENSORRT AND WITH_GPU) - set(TENSORRT_ROOT "" CACHE STRING "The root directory of TensorRT library") +if(USE_TENSORRT AND WITH_GPU) + set(TENSORRT_ROOT + "" + CACHE STRING "The root directory of TensorRT library") if("${TENSORRT_ROOT}" STREQUAL "") - message(FATAL_ERROR "The TENSORRT_ROOT is empty, you must assign it a value with CMake command. Such as: -DTENSORRT_ROOT=TENSORRT_ROOT_PATH ") + message( + FATAL_ERROR + "The TENSORRT_ROOT is empty, you must assign it a value with CMake command. Such as: -DTENSORRT_ROOT=TENSORRT_ROOT_PATH " + ) endif() set(TENSORRT_INCLUDE_DIR ${TENSORRT_ROOT}/include) set(TENSORRT_LIB_DIR ${TENSORRT_ROOT}/lib) file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS) - string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION - "${TENSORRT_VERSION_FILE_CONTENTS}") + string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" + TENSORRT_MAJOR_VERSION "${TENSORRT_VERSION_FILE_CONTENTS}") if("${TENSORRT_MAJOR_VERSION}" STREQUAL "") - file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h TENSORRT_VERSION_FILE_CONTENTS) - string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION - "${TENSORRT_VERSION_FILE_CONTENTS}") + file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h + TENSORRT_VERSION_FILE_CONTENTS) + string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" + TENSORRT_MAJOR_VERSION "${TENSORRT_VERSION_FILE_CONTENTS}") endif() if("${TENSORRT_MAJOR_VERSION}" STREQUAL "") message(SEND_ERROR "Failed to detect TensorRT version.") endif() string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1" - TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}") - message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. " - "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ") + TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}") + message( + STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. " + "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. 
") include_directories("${TENSORRT_INCLUDE_DIR}") link_directories("${TENSORRT_LIB_DIR}") endif() @@ -126,79 +146,117 @@ if(WITH_MKL) set(MATH_LIB ${MATH_LIB_PATH}/lib/mklml${CMAKE_STATIC_LIBRARY_SUFFIX} ${MATH_LIB_PATH}/lib/libiomp5md${CMAKE_STATIC_LIBRARY_SUFFIX}) else() - set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} - ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(MATH_LIB + ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) endif() set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn") if(EXISTS ${MKLDNN_PATH}) include_directories("${MKLDNN_PATH}/include") if(WIN32) - set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib) - else(WIN32) - set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) - endif(WIN32) + set(MKLDNN_LIB ${MKLDNN_PATH}/lib/dnnl.dll) + else() + set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libdnnl.so.3) + endif() endif() -elseif((NOT WITH_MIPS) AND (NOT WITH_SW)) +elseif((NOT (WITH_MIPS OR WITH_LOONGARCH)) AND (NOT WITH_SW)) set(OPENBLAS_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}openblas") include_directories("${OPENBLAS_LIB_PATH}/include/openblas") if(WIN32) - set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/openblas${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(MATH_LIB + ${OPENBLAS_LIB_PATH}/lib/openblas${CMAKE_STATIC_LIBRARY_SUFFIX}) else() - set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(MATH_LIB + ${OPENBLAS_LIB_PATH}/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX}) endif() endif() if(WITH_STATIC_LIB) - set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS + ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX} + ) else() if(WIN32) - set(DEPS ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS + ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX}) else() - set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS + ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX} + ) endif() endif() -if (WITH_ONNXRUNTIME) +if(WITH_ONNXRUNTIME) if(WIN32) - set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.lib paddle2onnx) + set(DEPS + ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.lib + paddle2onnx) elseif(APPLE) - set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.1.10.0.dylib paddle2onnx) + set(DEPS + ${DEPS} + ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.1.10.0.dylib + paddle2onnx) else() - set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.so.1.10.0 paddle2onnx) + set(DEPS + ${DEPS} + ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.so.1.10.0 + paddle2onnx) endif() endif() -if (NOT WIN32) +if(NOT WIN32) set(EXTERNAL_LIB "-lrt -ldl -lpthread") - set(DEPS ${DEPS} - ${MATH_LIB} ${MKLDNN_LIB} - glog gflags protobuf xxhash cryptopp + set(DEPS + ${DEPS} + ${MATH_LIB} + ${MKLDNN_LIB} + glog + gflags + protobuf + xxhash + cryptopp ${EXTERNAL_LIB}) + if(WITH_SHARED_PHI) + set(DEPS ${DEPS} ${PADDLE_LIB}/paddle/lib/libphi${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() else() - set(DEPS ${DEPS} - ${MATH_LIB} ${MKLDNN_LIB} - glog gflags_static libprotobuf xxhash cryptopp-static ${EXTERNAL_LIB}) + set(DEPS + ${DEPS} + ${MATH_LIB} + ${MKLDNN_LIB} + glog + gflags_static + libprotobuf + xxhash + cryptopp-static + ${EXTERNAL_LIB}) set(DEPS ${DEPS} shlwapi.lib) -endif(NOT WIN32) 
+endif() if(WITH_GPU) if(NOT WIN32) - if (USE_TENSORRT) - set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}) - set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}) + if(USE_TENSORRT) + set(DEPS ${DEPS} + ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS + ${DEPS} + ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}) endif() set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) else() if(USE_TENSORRT) - set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) - set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} + ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} + ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7) - set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} + ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_STATIC_LIBRARY_SUFFIX}) endif() endif() - set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) - set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) - set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX}) endif() endif() @@ -208,56 +266,97 @@ endif() if(WITH_XPU AND NOT WIN32) set(XPU_INSTALL_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}xpu") - set(DEPS ${DEPS} ${XPU_INSTALL_PATH}/lib/libxpuapi${CMAKE_SHARED_LIBRARY_SUFFIX}) - set(DEPS ${DEPS} ${XPU_INSTALL_PATH}/lib/libxpurt${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} + ${XPU_INSTALL_PATH}/lib/libxpuapi${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} + ${XPU_INSTALL_PATH}/lib/libxpurt${CMAKE_SHARED_LIBRARY_SUFFIX}) endif() if(WITH_NPU AND NOT WIN32) - set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libgraph${CMAKE_SHARED_LIBRARY_SUFFIX}) - set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libge_runner${CMAKE_SHARED_LIBRARY_SUFFIX}) - set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX}) - set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX}) - set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libacl_op_compiler${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS + ${DEPS} + ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libgraph${CMAKE_SHARED_LIBRARY_SUFFIX} + ) + set(DEPS + ${DEPS} + ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libge_runner${CMAKE_SHARED_LIBRARY_SUFFIX} + ) + set(DEPS + ${DEPS} + ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX} + ) + set(DEPS + ${DEPS} + ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX} + ) + set(DEPS + ${DEPS} + ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libacl_op_compiler${CMAKE_SHARED_LIBRARY_SUFFIX} + ) endif() add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) target_link_libraries(${DEMO_NAME} ${DEPS}) if(WIN32) if(USE_TENSORRT) - add_custom_command(TARGET ${DEMO_NAME} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_SHARED_LIBRARY_SUFFIX} - ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} - COMMAND ${CMAKE_COMMAND} -E copy 
${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX} - ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} - ) + add_custom_command( + TARGET ${DEMO_NAME} + POST_BUILD + COMMAND + ${CMAKE_COMMAND} -E copy + ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_SHARED_LIBRARY_SUFFIX} + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} + COMMAND + ${CMAKE_COMMAND} -E copy + ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX} + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7) - add_custom_command(TARGET ${DEMO_NAME} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_SHARED_LIBRARY_SUFFIX} - ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}) + add_custom_command( + TARGET ${DEMO_NAME} + POST_BUILD + COMMAND + ${CMAKE_COMMAND} -E copy + ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_SHARED_LIBRARY_SUFFIX} + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}) endif() endif() if(WITH_MKL) - add_custom_command(TARGET ${DEMO_NAME} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/mklml.dll ${CMAKE_BINARY_DIR}/Release - COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/libiomp5md.dll ${CMAKE_BINARY_DIR}/Release - COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_PATH}/lib/mkldnn.dll ${CMAKE_BINARY_DIR}/Release - ) + add_custom_command( + TARGET ${DEMO_NAME} + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/mklml.dll + ${CMAKE_BINARY_DIR}/Release + COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/libiomp5md.dll + ${CMAKE_BINARY_DIR}/Release + COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_PATH}/lib/dnnl.dll + ${CMAKE_BINARY_DIR}/Release) else() - add_custom_command(TARGET ${DEMO_NAME} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${OPENBLAS_LIB_PATH}/lib/openblas.dll ${CMAKE_BINARY_DIR}/Release - ) + add_custom_command( + TARGET ${DEMO_NAME} + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${OPENBLAS_LIB_PATH}/lib/openblas.dll + ${CMAKE_BINARY_DIR}/Release) endif() if(WITH_ONNXRUNTIME) - add_custom_command(TARGET ${DEMO_NAME} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.dll - ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} - COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib/paddle2onnx.dll - ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} - ) + add_custom_command( + TARGET ${DEMO_NAME} + POST_BUILD + COMMAND + ${CMAKE_COMMAND} -E copy + ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.dll + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} + COMMAND + ${CMAKE_COMMAND} -E copy + ${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib/paddle2onnx.dll + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}) endif() if(NOT WITH_STATIC_LIB) - add_custom_command(TARGET ${DEMO_NAME} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy "${PADDLE_LIB}/paddle/lib/paddle_inference.dll" ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} - ) + add_custom_command( + TARGET ${DEMO_NAME} + POST_BUILD + COMMAND + ${CMAKE_COMMAND} -E copy + "${PADDLE_LIB}/paddle/lib/paddle_inference.dll" + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}) endif() -endif() +endif() \ No newline at end of file diff --git a/example/auto_compression/pytorch_yolo_series/cpp_infer/README.md b/example/auto_compression/pytorch_yolo_series/cpp_infer/README.md index 0286c26df..9f7efdd0a 100644 --- a/example/auto_compression/pytorch_yolo_series/cpp_infer/README.md +++ b/example/auto_compression/pytorch_yolo_series/cpp_infer/README.md @@ -4,22 +4,22 @@ - CUDA、CUDNN:确认环境中已经安装CUDA和CUDNN,并且提前获取其安装路径。 -- TensorRT:可通过NVIDIA官网下载[TensorRT 
8.4.1.5](https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.1/tars/tensorrt-8.4.1.5.linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz)或其他版本安装包。 +- TensorRT:可通过NVIDIA官网下载[TensorRT 8.6.1.6](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/8.6.1/tars/TensorRT-8.6.1.6.Linux.x86_64-gnu.cuda-11.8.tar.gz)或其他版本安装包。 -- Paddle Inference C++预测库:编译develop版本请参考[编译文档](https://www.paddlepaddle.org.cn/inference/user_guides/source_compile.html)。编译完成后,会在build目录下生成`paddle_inference_install_dir`文件夹,这个就是我们需要的C++预测库文件。 +- Paddle Inference C++预测库:编译develop版本请参考[编译文档](https://www.paddlepaddle.org.cn/inference/user_guides/source_compile.html)。编译完成后,会在build目录下生成`paddle_inference_install_dir`文件夹,这个就是我们需要的C++预测库文件,或者在官网下载[c++推理库](https://www.paddlepaddle.org.cn/inference/v2.6/guides/install/download_lib.html) ## 编译可执行程序 - (1)修改`compile.sh`中依赖库路径,主要是以下内容: ```shell # Paddle Inference预测库路径 -LIB_DIR=/root/auto_compress/Paddle/build/paddle_inference_install_dir/ +LIB_DIR=/work/Paddle/build/paddle_inference_install_dir # CUDNN路径 CUDNN_LIB=/usr/lib/x86_64-linux-gnu/ # CUDA路径 CUDA_LIB=/usr/local/cuda/lib64 # TensorRT安装包路径,为TRT资源包解压完成后的绝对路径,其中包含`lib`和`include`文件夹 -TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/ +TENSORRT_ROOT=/work/TensorRT-8.6.1.6 ``` ## Paddle tensorRT测试 @@ -27,59 +27,42 @@ TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/ - YOLOv5 ``` # FP32 -./build/trt_run --model_file yolov5s_infer/model.pdmodel --params_file yolov5s_infer/model.pdiparams --run_mode=trt_fp32 +./build/trt_run --model_file yolov5_model/inference_model/model.pdmodel --params_file yolov5_model/inference_model/model.pdiparams --run_mode=trt_fp32 # FP16 -./build/trt_run --model_file yolov5s_infer/model.pdmodel --params_file yolov5s_infer/model.pdiparams --run_mode=trt_fp16 +./build/trt_run --model_file yolov5_model/inference_model/model.pdmodel --params_file yolov5_model/inference_model/model.pdiparams --run_mode=trt_fp16 # INT8 -./build/trt_run --model_file yolov5s_quant/model.pdmodel --params_file yolov5s_quant/model.pdiparams --run_mode=trt_int8 +./build/trt_run --model_file yolov5s_quantaware/model.pdmodel --params_file yolov5s_quantaware/model.pdiparams --run_mode=trt_int8 ``` - YOLOv6 ``` # FP32 -./build/trt_run --arch=YOLOv6 --model_file yolov6s_infer/model.pdmodel --params_file yolov6s_infer/model.pdiparams --run_mode=trt_fp32 +./build/trt_run --model_file yolov6_model/inference_model/model.pdmodel --params_file yolov6_model/inference_model/model.pdiparams --run_mode=trt_fp32 # FP16 -./build/trt_run --arch=YOLOv6 --model_file yolov6s_infer/model.pdmodel --params_file yolov6s_infer/model.pdiparams --run_mode=trt_fp16 +./build/trt_run --model_file yolov6_model/inference_model/model.pdmodel --params_file yolov6_model/inference_model/model.pdiparams --run_mode=trt_fp16 # INT8 -./build/trt_run --arch=YOLOv6 --model_file yolov6s_quant/model.pdmodel --params_file yolov6s_quant/model.pdiparams --run_mode=trt_int8 +./build/trt_run --model_file yolov6s_quantaware/model.pdmodel --params_file yolov6s_quantaware/model.pdiparams --run_mode=trt_int8 ``` - YOLOv7 ``` # FP32 -./build/trt_run --model_file yolov7_infer/model.pdmodel --params_file yolov7_infer/model.pdiparams --run_mode=trt_fp32 +./build/trt_run --model_file yolov7-tiny/inference_model/model.pdmodel --params_file yolov7-tiny/inference_model/model.pdiparams --run_mode=trt_fp32 # FP16 -./build/trt_run --model_file yolov7_infer/model.pdmodel --params_file yolov7_infer/model.pdiparams --run_mode=trt_fp16 +./build/trt_run --model_file 
yolov7-tiny/inference_model/model.pdmodel --params_file yolov7-tiny/inference_model/model.pdiparams --run_mode=trt_fp16 # INT8 -./build/trt_run --model_file yolov7_quant/model.pdmodel --params_file yolov7_quant/model.pdiparams --run_mode=trt_int8 +./build/trt_run --model_file yolov7-quantAware/model.pdmodel --params_file yolov7-quantAware/model.pdiparams --run_mode=trt_int8 ``` ## 原生TensorRT测试 ```shell # FP32 -trtexec --onnx=yolov5s.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw +trtexec --onnx=yolov5s.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp32:chw --outputIOFormats=fp32:chw # FP16 -trtexec --onnx=yolov5s.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --fp16 +trtexec --onnx=yolov5s.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp32:chw --outputIOFormats=fp32:chw --fp16 # INT8 -trtexec --onnx=yolov5s.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --int8 +trtexec --onnx=yolov5s.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp32:chw --outputIOFormats=fp32:chw --int8 ``` - 注:可把--onnx=yolov5s.onnx替换成yolov6s.onnx和yolov7.onnx模型 - -## 性能对比 - -| 预测库 | 模型 | 预测时延FP32
(ms) |预测时延FP16
(ms) | 预测时延INT8
(ms) | -| :--------: | :--------: |:-------- |:--------: | :---------------------: | -| Paddle TensorRT | yolov5s | 5.95ms | 2.44ms | 1.87ms | -| TensorRT | yolov5s | 6.16ms | 2.58ms | 2.07ms | -| | | | | | -| Paddle TensorRT | YOLOv6s | 9.06ms | 2.90ms | 1.83ms | -| TensorRT | YOLOv6s | 8.59ms | 2.83ms | 1.87ms | -| | | | | | -| Paddle TensorRT | YOLOv7 | 26.84ms | 7.44ms | 4.55ms | -| TensorRT | YOLOv7 | 28.25ms | 7.23ms | 4.67ms | - -环境: -- Tesla T4,TensorRT 8.4.1,CUDA 11.2 -- batch_size=1 diff --git a/example/auto_compression/pytorch_yolo_series/cpp_infer/compile.sh b/example/auto_compression/pytorch_yolo_series/cpp_infer/compile.sh index afff924b4..ac164faa9 100644 --- a/example/auto_compression/pytorch_yolo_series/cpp_infer/compile.sh +++ b/example/auto_compression/pytorch_yolo_series/cpp_infer/compile.sh @@ -14,10 +14,10 @@ WITH_MKL=ON WITH_GPU=ON USE_TENSORRT=ON -LIB_DIR=/root/auto_compress/Paddle/build/paddle_inference_install_dir/ +LIB_DIR=/work/Paddle/build/paddle_inference_install_dir CUDNN_LIB=/usr/lib/x86_64-linux-gnu/ CUDA_LIB=/usr/local/cuda/lib64 -TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/ +TENSORRT_ROOT=/work/TensorRT-8.6.1.6 WITH_ROCM=OFF ROCM_LIB=/opt/rocm/lib diff --git a/example/auto_compression/pytorch_yolo_series/cpp_infer/trt_run.cc b/example/auto_compression/pytorch_yolo_series/cpp_infer/trt_run.cc index 101cf33c8..56c10ab72 100644 --- a/example/auto_compression/pytorch_yolo_series/cpp_infer/trt_run.cc +++ b/example/auto_compression/pytorch_yolo_series/cpp_infer/trt_run.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
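+//
+// trt_run is a minimal latency benchmark: it builds a paddle_infer::Config from
+// --model_file/--params_file, enables GPU, and turns on the TensorRT engine when
+// --run_mode is trt_fp32/trt_fp16/trt_int8 (optionally registering dynamic shapes
+// for the "image" input via --use_dynamic_shape). It then runs FLAGS_warmup
+// untimed iterations followed by FLAGS_repeats timed ones and logs the average
+// latency in milliseconds.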
+ #include #include #include @@ -5,27 +19,28 @@ #include #include -#include -#include "paddle/include/paddle_inference_api.h" -#include "paddle/include/experimental/phi/common/float16.h" +#include "paddle_inference_api.h" using paddle_infer::Config; -using paddle_infer::Predictor; using paddle_infer::CreatePredictor; using paddle_infer::PrecisionType; -using phi::dtype::float16; +using paddle_infer::Predictor; DEFINE_string(model_dir, "", "Directory of the inference model."); DEFINE_string(model_file, "", "Path of the inference model file."); DEFINE_string(params_file, "", "Path of the inference params file."); -DEFINE_string(arch, "YOLOv5", "Architectures name, can be: YOLOv5, YOLOv6, YOLOv7."); -DEFINE_string(run_mode, "trt_fp32", "run_mode which can be: trt_fp32, trt_fp16 and trt_int8"); +DEFINE_string( + run_mode, + "paddle_gpu", + "run_mode which can be: trt_fp32, trt_fp16 and trt_int8 and paddle_gpu"); DEFINE_int32(batch_size, 1, "Batch size."); DEFINE_int32(gpu_id, 0, "GPU card ID num."); DEFINE_int32(trt_min_subgraph_size, 3, "tensorrt min_subgraph_size"); DEFINE_int32(warmup, 50, "warmup"); DEFINE_int32(repeats, 1000, "repeats"); +DEFINE_bool(use_dynamic_shape, false, "use trt dynaminc shape."); +DEFINE_bool(use_calib, true, "use trt int8 calibration."); using Time = decltype(std::chrono::high_resolution_clock::now()); Time time() { return std::chrono::high_resolution_clock::now(); }; @@ -38,89 +53,89 @@ double time_diff(Time t1, Time t2) { std::shared_ptr InitPredictor() { Config config; - std::string model_path; if (FLAGS_model_dir != "") { config.SetModel(FLAGS_model_dir); - model_path = FLAGS_model_dir.substr(0, FLAGS_model_dir.find_last_of("/")); - } else { - config.SetModel(FLAGS_model_file, FLAGS_params_file); - model_path = FLAGS_model_file.substr(0, FLAGS_model_file.find_last_of("/")); } - // enable tune - std::cout << "model_path: " << model_path << std::endl; - config.EnableUseGpu(256, FLAGS_gpu_id); + config.SetModel(FLAGS_model_file, FLAGS_params_file); + + config.EnableUseGpu(500, FLAGS_gpu_id); + if (FLAGS_run_mode == "trt_fp32") { - config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size, - PrecisionType::kFloat32, false, false); + config.EnableTensorRtEngine(1 << 30 * FLAGS_batch_size, + FLAGS_batch_size, + FLAGS_trt_min_subgraph_size, + PrecisionType::kFloat32, + false, + false); } else if (FLAGS_run_mode == "trt_fp16") { - config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size, - PrecisionType::kHalf, false, false); + config.EnableTensorRtEngine(1 << 30 * FLAGS_batch_size, + FLAGS_batch_size, + FLAGS_trt_min_subgraph_size, + PrecisionType::kHalf, + false, + false); } else if (FLAGS_run_mode == "trt_int8") { - config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size, - PrecisionType::kInt8, false, false); + config.EnableTensorRtEngine(1 << 30 * FLAGS_batch_size, + FLAGS_batch_size, + FLAGS_trt_min_subgraph_size, + PrecisionType::kInt8, + false, + FLAGS_use_calib); + } + if (FLAGS_use_dynamic_shape) { + std::map> min_input_shape = { + {"image", {1, 3, 640, 640}}}; + std::map> max_input_shape = { + {"image", {4, 3, 640, 640}}}; + std::map> opt_input_shape = { + {"image", {2, 3, 640, 640}}}; + config.SetTRTDynamicShapeInfo( + min_input_shape, max_input_shape, opt_input_shape); } + // Open the memory optim. 
config.EnableMemoryOptim(); + config.SwitchIrDebug(true); config.SwitchIrOptim(true); return CreatePredictor(config); } -template -void run(Predictor *predictor, const std::vector &input, - const std::vector &input_shape, type* out_data, std::vector out_shape) { - - // prepare input - int input_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, - std::multiplies()); - - auto input_names = predictor->GetInputNames(); - auto input_t = predictor->GetInputHandle(input_names[0]); - input_t->Reshape(input_shape); - input_t->CopyFromCpu(input.data()); - - for (int i = 0; i < FLAGS_warmup; ++i) - CHECK(predictor->Run()); +void run(Predictor *predictor, + const std::vector &input, + const std::vector &input_shape, + std::vector *out_data) { + int input_num = std::accumulate( + input_shape.begin(), input_shape.end(), 1, std::multiplies()); - auto st = time(); - for (int i = 0; i < FLAGS_repeats; ++i) { - auto input_names = predictor->GetInputNames(); - auto input_t = predictor->GetInputHandle(input_names[0]); + auto input_names = predictor->GetInputNames(); + auto output_names = predictor->GetOutputNames(); + auto input_t = predictor->GetInputHandle(input_names[0]); + input_t->Reshape(input_shape); + input_t->CopyFromCpu(input.data()); - input_t->Reshape(input_shape); - input_t->CopyFromCpu(input.data()); + for (size_t i = 0; i < FLAGS_warmup; ++i) CHECK(predictor->Run()); + auto st = time(); + for (size_t i = 0; i < FLAGS_repeats; ++i) { CHECK(predictor->Run()); - - auto output_names = predictor->GetOutputNames(); auto output_t = predictor->GetOutputHandle(output_names[0]); std::vector output_shape = output_t->shape(); - output_t->CopyToCpu(out_data); - + int out_num = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); + out_data->resize(out_num); + output_t->CopyToCpu(out_data->data()); } - - LOG(INFO) << "[" << FLAGS_run_mode << " bs-" << FLAGS_batch_size << " ] run avg time is " << time_diff(st, time()) / FLAGS_repeats + LOG(INFO) << "run avg time is " << time_diff(st, time()) / FLAGS_repeats << " ms"; } -int main(int argc, char *argv[]) -{ +int main(int argc, char *argv[]) { google::ParseCommandLineFlags(&argc, &argv, true); auto predictor = InitPredictor(); - - std::cout << "====== Use float instead of FP16 data ======" << std::endl; - std::vector input_data(FLAGS_batch_size * 3 * 640 * 640, float(1.0)); std::vector input_shape = {FLAGS_batch_size, 3, 640, 640}; + std::vector input_data(FLAGS_batch_size * 3 * 640 * 640); + for (size_t i = 0; i < input_data.size(); ++i) input_data[i] = i % 255 * 0.1; + std::vector out_data; + run(predictor.get(), input_data, input_shape, &out_data); - int out_box_shape = 25200; - if (FLAGS_arch == "YOLOv6"){ - out_box_shape = 8400; - } - float* out_data; - std::vector out_shape{ FLAGS_batch_size, 1, out_box_shape, 85}; - int out_data_size = FLAGS_batch_size * out_box_shape * 85; - - // Only use Pinned mem for D2H. 
- cudaHostAlloc((void**)&out_data, sizeof(float) * out_data_size, cudaHostAllocMapped); - - run(predictor.get(), input_data, input_shape, out_data, out_shape); return 0; } \ No newline at end of file diff --git a/example/auto_compression/pytorch_yolo_series/paddle_inference_eval.py b/example/auto_compression/pytorch_yolo_series/paddle_inference_eval.py index a1df31b78..382336613 100644 --- a/example/auto_compression/pytorch_yolo_series/paddle_inference_eval.py +++ b/example/auto_compression/pytorch_yolo_series/paddle_inference_eval.py @@ -79,7 +79,8 @@ def argsparser(): "--device", type=str, default="GPU", - help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU", + help= + "Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU", ) parser.add_argument( "--arch", type=str, default="YOLOv5", help="architectures name.") @@ -180,8 +181,9 @@ def draw_box(img, boxes, scores, cls_ids, conf=0.5, class_names=None): txt_size = cv2.getTextSize(text, font, 0.4, 1)[0] cv2.rectangle(img, (x0, y0), (x1, y1), color, 2) - cv2.rectangle(img, (x0, y0 + 1), ( - x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])), color, -1) + cv2.rectangle(img, (x0, y0 + 1), (x0 + txt_size[0] + 1, + y0 + int(1.5 * txt_size[1])), color, + -1) cv2.putText( img, text, (x0, y0 + txt_size[1]), @@ -288,8 +290,8 @@ def load_predictor( dynamic_shape_file = os.path.join(FLAGS.model_path, "dynamic_shape.txt") if os.path.exists(dynamic_shape_file): - config.enable_tuned_tensorrt_dynamic_shape(dynamic_shape_file, - True) + config.enable_tuned_tensorrt_dynamic_shape( + dynamic_shape_file, True) print("trt set dynamic shape done!") else: config.collect_shape_range_info(dynamic_shape_file) @@ -315,7 +317,8 @@ def eval(predictor, val_loader, anno_file, rerun_flag=False): input_names = predictor.get_input_names() output_names = predictor.get_output_names() boxes_tensor = predictor.get_output_handle(output_names[0]) - for batch_id, data in enumerate(val_loader): + for batch_id, data in tqdm( + enumerate(val_loader), total=len(val_loader), desc='Evaluating'): data_all = {k: np.array(v) for k, v in data.items()} inputs = {} if FLAGS.arch == "YOLOv6": @@ -345,7 +348,6 @@ def eval(predictor, val_loader, anno_file, rerun_flag=False): cpu_mems += cpu_mem gpu_mems += gpu_mem if batch_id % 100 == 0: - print("Eval iter:", batch_id) sys.stdout.flush() print("[Benchmark]Avg cpu_mem:{} MB, avg gpu_mem: {} MB".format( cpu_mems / sample_nums, gpu_mems / sample_nums)) @@ -469,4 +471,4 @@ def main(): # DataLoader need run on cpu paddle.set_device("cpu") - main() + main() \ No newline at end of file diff --git a/example/post_training_quantization/pytorch_yolo_series/README.md b/example/post_training_quantization/pytorch_yolo_series/README.md index 63a7d96c1..782192d89 100755 --- a/example/post_training_quantization/pytorch_yolo_series/README.md +++ b/example/post_training_quantization/pytorch_yolo_series/README.md @@ -35,13 +35,13 @@ | YOLOv7 | KL离线量化 | 640*640 | 50.2 | - | - | 4.55ms | - | - | 说明: -- mAP的指标均在COCO val2017数据集中评测得到。 +- mAP的指标均在COCO val2017数据集中评测得到。以上指标通过c++测速得到 ## 3. 
离线量化流程 #### 3.1 准备环境 -- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) -- PaddleSlim > 2.3版本 +- PaddlePaddle 2.6 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) +- PaddleSlim 2.6 - X2Paddle >= 1.3.9 - opencv-python @@ -49,9 +49,9 @@ (1)安装paddlepaddle: ```shell # CPU -pip install paddlepaddle -# GPU -pip install paddlepaddle-gpu +python -m pip install paddlepaddle==2.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple +# GPU cuda11.2为例 +python -m pip install paddlepaddle-gpu==2.6.0.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html ``` (2)安装paddleslim: @@ -122,7 +122,7 @@ python eval.py --config_path=./configs/yolov5s_ptq.yaml #### 3.6 提高离线量化精度 ###### 3.6.1 量化分析工具 -本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisPTQ```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。```paddleslim.quant.AnalysisPTQ```详解见[AnalysisPTQ.md](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/tutorials/quant/post_training_quantization.md)。 +本节介绍如何使用量化分析工具提升离线量化精度。离线量化功能仅需使用少量数据,且使用简单、能快速得到量化模型,但往往会造成较大的精度损失。PaddleSlim提供量化分析工具,会使用接口```paddleslim.quant.AnalysisPTQ```,可视化展示出不适合量化的层,通过跳过这些层,提高离线量化模型精度。```paddleslim.quant.AnalysisPTQ```详解见[离线量化](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/tutorials/quant/post_training_quantization.md)。 由于YOLOv6离线量化效果较差,以YOLOv6为例,量化分析工具具体使用方法如下: @@ -162,7 +162,7 @@ python post_quant.py --config_path=./configs/yolov6s_analyzed_ptq.yaml --save_di 注:分析之后若需要直接产出符合目标精度的量化模型,demo代码不会使用少量数据集验证,会自动使用全量验证数据。 -量化分析工具详细介绍见[量化分析工具介绍](../analysis.md) +量化分析工具详细介绍见[量化分析工具介绍](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/tutorials/quant/static/Analysis.md) ###### 3.6.2 精度重构工具 本节介绍如何使用精度重构工具提高精度。该工具的思想是,通过最小化量化前后模型输出的重构误差(minimizing the reconstruction error,MRE),学习权重的取整方式(上取整or下取整),从而`fine-tune`经量化后的模型的权重,提高精度。同样以YOLOv6为例,运行命令如下: @@ -215,16 +215,15 @@ python fine_tune.py --config_path=./configs/yolov6s_fine_tune.yaml --simulate_ac | val_image_dir | COCO数据集中验证图像的目录名,默认为val2017 | | val_anno_path | 指定COCO数据集的注释(annotation)文件路径,这是包含验证集标注信息的JSON文件,默认为annotations/instances_val2017.json | | benchmark | 指定是否运行性能基准测试。如果设置为True,程序将会进行性能测试 | -| device | 使用GPU或者CPU预测,可选CPU/GPU/XPU,默认设置为GPU | -| use_trt | 是否使用TensorRT进行预测| -| use_mkldnn | 是否使用MKL-DNN加速库,注意use_mkldnn与use_gpu同时为True时,将忽略enable_mkldnn,而使用GPU预测| -| use_dynamic_shape | 是否使用动态形状(dynamic_shape)功能 | -| precision | fp32/fp16/int8| +| device | 使用GPU或者CPU预测,可选CPU/GPU/XPU,默认设置为GPU | +| use_trt | 是否使用 TesorRT 预测引擎 | +| use_mkldnn | 是否启用```MKL-DNN```加速库,注意```use_mkldnn```与```use_gpu```同时为```True```时,将忽略```enable_mkldnn```,而使用```GPU```预测 | +| cpu_threads | CPU预测时,使用CPU线程数量,默认10 | +| precision | 预测精度,包括`fp32/fp16/int8` | | arch | 指定所使用的模型架构的名称,例如YOLOv5 | | img_shape | 指定模型输入的图像尺寸 | +| use_dynamic_shape | 是否使用动态shape,如果使用动态shape,则设置为True,否则设置为False | | batch_size | 指定模型输入的批处理大小 | -| use_mkldnn | 指定是否使用MKLDNN加速(主要针对CPU)| -| cpu_threads | 指定在CPU上使用的线程数 | 首先,我们拥有的yolov6.onnx,我们需要把ONNX模型转成paddle模型,具体参考使用[X2Paddle迁移推理模型](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/model_convert/convert_with_x2paddle_cn.html#x2paddle) - 安装X2Paddle @@ -242,7 +241,7 @@ python setup.py install ```shell x2paddle --framework=onnx --model=yolov6s.onnx --save_dir=yolov6_model ``` -- TensorRT Python部署 +#### 4.1 TensorRT Python部署 
 使用[paddle_inference_eval.py](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/example/auto_compression/pytorch_yolo_series/paddle_inference_eval.py)部署
 ```shell
 python paddle_inference_eval.py --model_path=yolov6_model/inference_model --dataset_dir=datasets/coco --use_trt=True --precision=fp32 --arch=YOLOv6
 ```
@@ -251,7 +250,11 @@ python paddle_inference_eval.py --model_path=yolov6_model/inference_model --data
 ```shell
 python paddle_inference_eval.py --model_path=yolov6s_ptq_out --dataset_dir=datasets/coco --use_trt=True --precision=int8 --arch=YOLOv6
 ```
-- C++部署
+#### 4.2 MKLDNN Python部署
+```shell
+python paddle_inference_eval.py --model_path=yolov6_model/inference_model --dataset_dir=datasets/coco --device=CPU --use_mkldnn=True --precision=fp32 --arch=YOLOv6
+```
+#### 4.3 C++部署
 具体可参考[运行PP-YOLOE-l目标检测模型样例](https://github.com/PaddlePaddle/Paddle-Inference-Demo/tree/master/c%2B%2B/gpu/ppyoloe_crn_l)
 将compile.sh中DEMO_NAME修改为yolov6_test,并且将ppyoloe_crn_l.cc修改为yolov6_test.cc,根据环境修改相关配置库。
 运行bash compile.sh编译样例。
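+
+注:作为 4.2 节 MKLDNN 部署命令的补充,下面给出一段示意代码(仅为草图,基于 Paddle Inference Python API;模型路径仅为示例),展示 `--device=CPU`、`--use_mkldnn`、`--cpu_threads` 大致对应的配置方式:
+
+```python
+from paddle.inference import Config, create_predictor
+
+# 示例路径,实际由 --model_path 指定
+config = Config("yolov6_model/inference_model/model.pdmodel",
+                "yolov6_model/inference_model/model.pdiparams")
+config.disable_gpu()                         # 对应 --device=CPU
+config.enable_mkldnn()                       # 对应 --use_mkldnn=True
+config.set_cpu_math_library_num_threads(10)  # 对应 --cpu_threads=10
+predictor = create_predictor(config)
+```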