forked from open-mmlab/mmdetection
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Speedup the Video Inference by Accelerating data-loading Stage (open-…
…mmlab#7832) * add a faster inference for video * Fix typos * modify typo * modify the numpy array to torch gpu * fix lint * add description * add documents * fix typro * fix lint * fix lint * fix lint again * fix a mistake
- Loading branch information
1 parent
280cc7d
commit b1f40ef
Showing
3 changed files
with
166 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import argparse | ||
|
||
import cv2 | ||
import mmcv | ||
import numpy as np | ||
import torch | ||
from torchvision.transforms import functional as F | ||
|
||
from mmdet.apis import init_detector | ||
from mmdet.datasets.pipelines import Compose | ||
|
||
try: | ||
import ffmpegcv | ||
except ImportError: | ||
raise ImportError( | ||
'Please install ffmpegcv with:\n\n pip install ffmpegcv') | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description='MMDetection video demo with GPU acceleration') | ||
parser.add_argument('video', help='Video file') | ||
parser.add_argument('config', help='Config file') | ||
parser.add_argument('checkpoint', help='Checkpoint file') | ||
parser.add_argument( | ||
'--device', default='cuda:0', help='Device used for inference') | ||
parser.add_argument( | ||
'--score-thr', type=float, default=0.3, help='Bbox score threshold') | ||
parser.add_argument('--out', type=str, help='Output video file') | ||
parser.add_argument('--show', action='store_true', help='Show video') | ||
parser.add_argument( | ||
'--nvdecode', action='store_true', help='Use NVIDIA decoder') | ||
parser.add_argument( | ||
'--wait-time', | ||
type=float, | ||
default=1, | ||
help='The interval of show (s), 0 is block') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def prefetch_img_metas(cfg, ori_wh): | ||
w, h = ori_wh | ||
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' | ||
test_pipeline = Compose(cfg.data.test.pipeline) | ||
data = {'img': np.zeros((h, w, 3), dtype=np.uint8)} | ||
data = test_pipeline(data) | ||
img_metas = data['img_metas'][0].data | ||
return img_metas | ||
|
||
|
||
def process_img(frame_resize, img_metas, device): | ||
assert frame_resize.shape == img_metas['pad_shape'] | ||
frame_cuda = torch.from_numpy(frame_resize).to(device).float() | ||
frame_cuda = frame_cuda.permute(2, 0, 1) # HWC to CHW | ||
mean = torch.from_numpy(img_metas['img_norm_cfg']['mean']).to(device) | ||
std = torch.from_numpy(img_metas['img_norm_cfg']['std']).to(device) | ||
frame_cuda = F.normalize(frame_cuda, mean=mean, std=std, inplace=True) | ||
frame_cuda = frame_cuda[None, :, :, :] # NCHW | ||
data = {'img': [frame_cuda], 'img_metas': [[img_metas]]} | ||
return data | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
assert args.out or args.show, \ | ||
('Please specify at least one operation (save/show the ' | ||
'video) with the argument "--out" or "--show"') | ||
|
||
model = init_detector(args.config, args.checkpoint, device=args.device) | ||
|
||
if args.nvdecode: | ||
VideoCapture = ffmpegcv.VideoCaptureNV | ||
else: | ||
VideoCapture = ffmpegcv.VideoCapture | ||
video_origin = VideoCapture(args.video) | ||
img_metas = prefetch_img_metas(model.cfg, | ||
(video_origin.width, video_origin.height)) | ||
resize_wh = img_metas['pad_shape'][1::-1] | ||
video_resize = VideoCapture( | ||
args.video, | ||
resize=resize_wh, | ||
resize_keepratio=True, | ||
resize_keepratioalign='topleft', | ||
pix_fmt='rgb24') | ||
video_writer = None | ||
if args.out: | ||
video_writer = ffmpegcv.VideoWriter(args.out, fps=video_origin.fps) | ||
|
||
with torch.no_grad(): | ||
for frame_resize, frame_origin in zip( | ||
mmcv.track_iter_progress(video_resize), video_origin): | ||
data = process_img(frame_resize, img_metas, args.device) | ||
result = model(return_loss=False, rescale=True, **data)[0] | ||
frame_mask = model.show_result( | ||
frame_origin, result, score_thr=args.score_thr) | ||
if args.show: | ||
cv2.namedWindow('video', 0) | ||
mmcv.imshow(frame_mask, 'video', args.wait_time) | ||
if args.out: | ||
video_writer.write(frame_mask) | ||
|
||
if video_writer: | ||
video_writer.release() | ||
video_origin.release() | ||
video_resize.release() | ||
|
||
cv2.destroyAllWindows() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters