Skip to content

Commit

Permalink
separate Paddle Predictor to copy2gpu, infer, copy2cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
TingquanGao committed Nov 8, 2024
1 parent f8357ca commit b73c503
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 62 deletions.
58 changes: 31 additions & 27 deletions docs/module_usage/instructions/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,44 @@ python main.py \
在开启 Benchmark 后,将自动打印 benchmark 指标:

```
+-------------------+-----------------+------+---------------+
| Stage | Total Time (ms) | Nums | Avg Time (ms) |
+-------------------+-----------------+------+---------------+
| ReadCmp | 49.95107651 | 10 | 4.99510765 |
| Resize | 8.48054886 | 10 | 0.84805489 |
| Normalize | 23.08964729 | 10 | 2.30896473 |
| ToCHWImage | 0.02717972 | 10 | 0.00271797 |
| ImageDetPredictor | 75.94108582 | 10 | 7.59410858 |
| DetPostProcess | 0.26535988 | 10 | 0.02653599 |
+-------------------+-----------------+------+---------------+
+----------------+-----------------+------+---------------+
| Stage | Total Time (ms) | Nums | Avg Time (ms) |
+----------------+-----------------+------+---------------+
| ReadCmp | 185.48870087 | 10 | 18.54887009 |
| Resize | 16.95227623 | 30 | 0.56507587 |
| Normalize | 41.12100601 | 30 | 1.37070020 |
| ToCHWImage | 0.05745888 | 30 | 0.00191530 |
| Copy2GPU | 14.58549500 | 10 | 1.45854950 |
| Infer | 100.14462471 | 10 | 10.01446247 |
| Copy2CPU | 9.54508781 | 10 | 0.95450878 |
| DetPostProcess | 0.56767464 | 30 | 0.01892249 |
+----------------+-----------------+------+---------------+
+-------------+-----------------+------+---------------+
| Stage | Total Time (ms) | Nums | Avg Time (ms) |
+-------------+-----------------+------+---------------+
| PreProcess | 81.54845238 | 10 | 8.15484524 |
| Inference | 75.94108582 | 10 | 7.59410858 |
| PostProcess | 0.26535988 | 10 | 0.02653599 |
| End2End | 161.07797623 | 10 | 16.10779762 |
| WarmUp | 5496.41847610 | 5 | 1099.28369522 |
| PreProcess | 243.61944199 | 30 | 8.12064807 |
| Inference | 124.27520752 | 30 | 4.14250692 |
| PostProcess | 0.56767464 | 30 | 0.01892249 |
| End2End | 379.70948219 | 30 | 12.65698274 |
| WarmUp | 9465.68179131 | 5 | 1893.13635826 |
+-------------+-----------------+------+---------------+
```

在 Benchmark 结果中,会统计该模型全部组件(`Component`)的总耗时(`Total Time`,单位为“毫秒”)、调用次数(`Nums`)、调用平均执行耗时`Avg Time`,单位为“毫秒”),以及按预热(`WarmUp`)、预处理(`PreProcess`)、模型推理(`Inference`)、后处理(`PostProcess`)和端到端(`End2End`)进行划分的耗时统计,包括每个阶段的总耗时(`Total Time`,单位为“毫秒”)、样本数(`Nums`和单样本平均执行耗时`Avg Time`,单位为“毫秒”),同时,保存相关指标会到本地 `./benchmark.csv` 文件中:
在 Benchmark 结果中,会统计该模型全部组件(`Component`)的总耗时(`Total Time`,单位为“毫秒”)、**调用次数**`Nums`)、**调用**平均执行耗时`Avg Time`,单位为“毫秒”),以及按预热(`WarmUp`)、预处理(`PreProcess`)、模型推理(`Inference`)、后处理(`PostProcess`)和端到端(`End2End`)进行划分的耗时统计,包括每个阶段的总耗时(`Total Time`,单位为“毫秒”)、**样本数**`Nums`**单样本**平均执行耗时`Avg Time`,单位为“毫秒”),同时,保存相关指标会到本地 `./benchmark.csv` 文件中:

```csv
Stage,Total Time (ms),Nums,Avg Time (ms)
ReadCmp,0.04995107650756836,10,0.004995107650756836
Resize,0.008480548858642578,10,0.0008480548858642578
Normalize,0.02308964729309082,10,0.002308964729309082
ToCHWImage,2.7179718017578125e-05,10,2.7179718017578126e-06
ImageDetPredictor,0.07594108581542969,10,0.007594108581542969
DetPostProcess,0.00026535987854003906,10,2.6535987854003906e-05
PreProcess,0.08154845237731934,10,0.008154845237731934
Inference,0.07594108581542969,10,0.007594108581542969
PostProcess,0.00026535987854003906,10,2.6535987854003906e-05
End2End,0.16107797622680664,10,0.016107797622680664
WarmUp,5.496418476104736,5,1.0992836952209473
ReadCmp,0.18548870086669922,10,0.018548870086669923
Resize,0.0169522762298584,30,0.0005650758743286133
Normalize,0.04112100601196289,30,0.001370700200398763
ToCHWImage,5.745887756347656e-05,30,1.915295918782552e-06
Copy2GPU,0.014585494995117188,10,0.0014585494995117188
Infer,0.10014462471008301,10,0.0100144624710083
Copy2CPU,0.009545087814331055,10,0.0009545087814331055
DetPostProcess,0.0005676746368408203,30,1.892248789469401e-05
PreProcess,0.24361944198608398,30,0.0081206480662028
Inference,0.12427520751953125,30,0.0041425069173177086
PostProcess,0.0005676746368408203,30,1.892248789469401e-05
End2End,0.37970948219299316,30,0.012656982739766438
WarmUp,9.465681791305542,5,1.8931363582611085
```
2 changes: 1 addition & 1 deletion paddlex/configs/object_detection/PicoDet-S.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Export:
weight_path: https://paddledet.bj.bcebos.com/models/picodet_s_320_coco_lcnet.pdparams

Predict:
batch_size: 1
batch_size: 3
model_dir: "output/best_model/inference"
input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png"
kernel_option:
Expand Down
7 changes: 7 additions & 0 deletions paddlex/inference/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def _check_args_key(args):
f"The parameter ({param.name}) is needed by {self.__class__.__name__}, but {list(args.keys())} only found!"
)

if self.inputs is None:
return [({}, None)]

if self.need_batch_input:
args = {}
for input_ in input_list:
Expand Down Expand Up @@ -266,6 +269,10 @@ def keep_input(self):
def name(self):
return getattr(self, "NAME", self.__class__.__name__)

@property
def sub_cmps(self):
return None

@abstractmethod
def apply(self, input):
raise NotImplementedError
Expand Down
93 changes: 69 additions & 24 deletions paddlex/inference/components/paddle_predictor/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,42 @@
from ..base import BaseComponent


class Copy2GPU(BaseComponent):

def __init__(self, input_handlers):
super().__init__()
self.input_handlers = input_handlers

def apply(self, x):
for idx in range(len(x)):
self.input_handlers[idx].reshape(x[idx].shape)
self.input_handlers[idx].copy_from_cpu(x[idx])


class Copy2CPU(BaseComponent):

def __init__(self, output_handlers):
super().__init__()
self.output_handlers = output_handlers

def apply(self):
output = []
for out_tensor in self.output_handlers:
batch = out_tensor.copy_to_cpu()
output.append(batch)
return output


class Infer(BaseComponent):

def __init__(self, predictor):
super().__init__()
self.predictor = predictor

def apply(self):
self.predictor.run()


class BasePaddlePredictor(BaseComponent):
"""Predictor based on Paddle Inference"""

Expand Down Expand Up @@ -56,12 +92,13 @@ def _reset(self):
self.option = PaddlePredictorOption()
logging.debug(f"Env: {self.option}")
(
self.predictor,
self.inference_config,
self.input_names,
self.input_handlers,
self.output_handlers,
predictor,
input_handlers,
output_handlers,
) = self._create()
self.copy2gpu = Copy2GPU(input_handlers)
self.copy2cpu = Copy2CPU(output_handlers)
self.infer = Infer(predictor)
self.option.changed = False

def _create(self):
Expand Down Expand Up @@ -169,43 +206,46 @@ def _create(self):
for output_name in output_names:
output_handler = predictor.get_output_handle(output_name)
output_handlers.append(output_handler)
return predictor, config, input_names, input_handlers, output_handlers

def get_input_names(self):
"""get input names"""
return self.input_names
return predictor, input_handlers, output_handlers

def apply(self, **kwargs):
if self.option.changed:
self._reset()
x = self.to_batch(**kwargs)
for idx in range(len(x)):
self.input_handlers[idx].reshape(x[idx].shape)
self.input_handlers[idx].copy_from_cpu(x[idx])

self.predictor.run()
output = []
for out_tensor in self.output_handlers:
batch = out_tensor.copy_to_cpu()
output.append(batch)
return self.format_output(output)
batches = self.to_batch(**kwargs)
self.copy2gpu.apply(batches)
self.infer.apply()
pred = self.copy2cpu.apply()
return self.format_output(pred)

def format_output(self, pred):
return [{"pred": res} for res in zip(*pred)]
@property
def sub_cmps(self):
return {
"Copy2GPU": self.copy2gpu,
"Infer": self.infer,
"Copy2CPU": self.copy2cpu,
}

@abstractmethod
def to_batch(self):
raise NotImplementedError

@abstractmethod
def format_output(self, pred):
return [{"pred": res} for res in zip(*pred)]

class ImagePredictor(BasePaddlePredictor):

class ImagePredictor(BasePaddlePredictor):
INPUT_KEYS = "img"
OUTPUT_KEYS = "pred"
DEAULT_INPUTS = {"img": "img"}
DEAULT_OUTPUTS = {"pred": "pred"}

def to_batch(self, img):
return [np.stack(img, axis=0).astype(dtype=np.float32, copy=False)]

def format_output(self, pred):
return [{"pred": res} for res in zip(*pred)]


class ImageDetPredictor(BasePaddlePredictor):

Expand Down Expand Up @@ -276,9 +316,14 @@ def format_output(self, pred):
class TSPPPredictor(BasePaddlePredictor):

INPUT_KEYS = "ts"
OUTPUT_KEYS = "pred"
DEAULT_INPUTS = {"ts": "ts"}
DEAULT_OUTPUTS = {"pred": "pred"}

def to_batch(self, ts):
n = len(ts[0])
x = [np.stack([lst[i] for lst in ts], axis=0) for i in range(n)]
return x

def format_output(self, pred):
return [{"pred": res} for res in zip(*pred)]
33 changes: 23 additions & 10 deletions paddlex/inference/utils/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,29 +42,42 @@ def warmup_stop(self, warmup_num):
self._reset()

def _reset(self):
for name in self._components:
cmp = self._components[name]
for name, cmp in self.iterate_cmp(self._components):
cmp.timer.reset()
self._e2e_tic = time.time()

def iterate_cmp(self, cmps):
if cmps is None:
return
for name, cmp in cmps.items():
if cmp.sub_cmps is not None:
yield from self.iterate_cmp(cmp.sub_cmps)
yield name, cmp

def gather(self, e2e_num):
# lazy import for avoiding circular import
from ..components.paddle_predictor import BasePaddlePredictor

detail = []
summary = {"preprocess": 0, "inference": 0, "postprocess": 0}
op_tag = "preprocess"
for name in self._components:
cmp = self._components[name]
times = cmp.timer.logs
counts = len(times)
avg = np.mean(times)
total = np.sum(times)
detail.append((name, total, counts, avg))
for name, cmp in self._components.items():
if isinstance(cmp, BasePaddlePredictor):
summary["inference"] += total
# TODO(gaotingquan): show by hierarchy. Now dont show xxxPredictor benchmark info to ensure mutual exclusivity between components.
for name, sub_cmp in cmp.sub_cmps.items():
times = sub_cmp.timer.logs
counts = len(times)
avg = np.mean(times)
total = np.sum(times)
detail.append((name, total, counts, avg))
summary["inference"] += total
op_tag = "postprocess"
else:
times = cmp.timer.logs
counts = len(times)
avg = np.mean(times)
total = np.sum(times)
detail.append((name, total, counts, avg))
summary[op_tag] += total

summary = [
Expand Down

0 comments on commit b73c503

Please sign in to comment.