Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

质量评价模型合并 #90

Merged
merged 19 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
Binary file not shown.
8 changes: 8 additions & 0 deletions bmf/demo/video_quality_assessment/module_utils/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging


def get_logger():
    """Return the demo-wide logger instance (named 'main')."""
    logger = logging.getLogger('main')
    return logger
109 changes: 109 additions & 0 deletions bmf/demo/video_quality_assessment/module_utils/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import sys
if sys.version_info.major == 2:
from Queue import *
else:
from queue import *
import bmf

#from module_utils.util import generate_out_packets


def generate_out_packets(packet, np_arr, out_fmt):
    """Wrap a processed ndarray into a new bmf Packet.

    Copies pts/time_base from the source packet's frame and the packet
    timestamp from the source packet, so stream timing is preserved.
    """
    frame = bmf.VideoFrame.from_ndarray(np_arr, format=out_fmt)
    src_frame = packet.get_data()
    frame.pts = src_frame.pts
    frame.time_base = src_frame.time_base

    out_pkt = bmf.Packet()
    out_pkt.set_timestamp(packet.get_timestamp())
    out_pkt.set_data(frame)
    return out_pkt


class SyncModule(bmf.Module):
    """Base bmf module that buffers incoming frames so a subclass can
    process a sliding window of `nb_in` frames via core_process().

    The first/last frames are replicated (`_margin_num` times each) so the
    window is always full, and the output packet reuses the timing of the
    window's center frame.
    """

    def __init__(self, node=None, nb_in=1, in_fmt='yuv420p', out_fmt='yuv420p'):
        """
        nb_in: the number of frames for core_process function
        in_fmt: the pixel format of frames for core_process function
        out_fmt: the pixel format of frame returned by core_process function
        """
        self._node = node

        # Padding count on each side of the center (output) frame.
        self._margin_num = (nb_in - 1) // 2
        self._out_frame_index = self._margin_num
        self._in_frame_num = nb_in

        self._in_fmt = in_fmt
        self._out_fmt = out_fmt

        self._in_packets = []   # buffered input packets
        self._frames = []       # ndarray for each buffered packet
        self._eof = False

    def process(self, task):
        # Fix: removed leftover debug print() calls that spammed stdout on
        # every task / packet.
        input_queue = task.get_inputs()[0]
        output_queue = task.get_outputs()[0]

        while not input_queue.empty():
            pkt = input_queue.get()
            pkt_timestamp = pkt.get_timestamp()
            pkt_data = pkt.get_data()

            if pkt_timestamp == bmf.Timestamp.EOF:
                self._eof = True
            if pkt_data is not None:
                self._in_packets.append(pkt)
                self._frames.append(pkt.get_data().to_ndarray(format=self._in_fmt))

                # padding first frame.
                if len(self._in_packets) == 1:
                    for _ in range(self._margin_num):
                        self._in_packets.append(self._in_packets[0])
                        self._frames.append(self._frames[0])

            if self._eof:
                # padding last frame.
                for _ in range(self._margin_num):
                    self._in_packets.append(self._in_packets[-1])
                    self._frames.append(self._frames[-1])
                # NOTE(review): frames are only drained once EOF is seen, so
                # the whole stream is buffered in memory until then — confirm
                # this is intended for long inputs.
                self._consume(output_queue)

                output_queue.put(bmf.Packet.generate_eof_packet())
                task.set_timestamp(bmf.Timestamp.DONE)

        return bmf.ProcessResult.OK

    def _consume(self, output_queue):
        # Emit one output packet per full window, sliding by one frame.
        while len(self._in_packets) >= self._in_frame_num:
            out_frame = self.core_process(self._frames[:self._in_frame_num])
            out_packet = generate_out_packets(self._in_packets[self._out_frame_index], out_frame, self._out_fmt)
            output_queue.put(out_packet)
            self._in_packets.pop(0)
            self._frames.pop(0)

    def core_process(self, frames):
        """
        user defined, process frames to output one frame, pass through by default
        frames: input frames, list format
        """
        return frames[0]

    def clean(self):
        pass

    def close(self):
        self.clean()

    def reset(self):
        # Fix: also drop any buffered packets/frames; the original left the
        # buffers populated, leaking stale frames into the next stream.
        self._in_packets = []
        self._frames = []
        self._eof = False








17 changes: 17 additions & 0 deletions bmf/demo/video_quality_assessment/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
### bmf- video_quality_assessment demo
**Note:** before running, create a `models` directory and download the ONNX model file into it. The model file path is `releases/download/files/vqa_4kpgc_1.onnx`.

1. Algorithm introduction:
The vqa model evaluates video quality for 4kpgc scenes, with targeted algorithm design and optimization for specific spatiotemporal distortion characteristics. The model is trained and tested on NAIC (National Artificial Intelligence Challenge) competition data. The model's average value across srcc, plcc and ur (usability ratio) reaches 0.855, winning the championship of the AI+ Video Quality Evaluation track of the National Artificial Intelligence Competition. Details: https://naic.pcl.ac.cn/contest/17/53

2. Getting started quickly:
By running the vqa_demo.py file, you can quickly try out video quality assessment. The quality score ranges from 0 to 10; the higher the score, the lower the degree of distortion and the better the video quality.

3. Requirements
os
bmf
sys
time
json
numpy
onnxruntime
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"vqa_4kpgc": 6.68,
"vqa_4kpgc_version": "v1.0",
"num_crop": 5,
"cau_frames_num": 1
}
105 changes: 105 additions & 0 deletions bmf/demo/video_quality_assessment/vqa_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import bmf
import os
import sys
import cv2

def get_width_and_height(video_path):
    """Return (width, height) of the video at `video_path`, as ints."""
    cap = cv2.VideoCapture(video_path)
    frame_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    frame_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    cap.release()
    return int(frame_w), int(frame_h)

def get_duration(video_path):
    """Return the video duration in seconds (frame count / fps)."""
    cap = cv2.VideoCapture(video_path)
    # OpenCV 2.x used "CV_CAP_PROP_FPS" for this property.
    frames_per_sec = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return total_frames / frames_per_sec

def segment_decode_ticks(video_path, seg_dur=4.0, lv1_dur_thres=24.0, max_dur=1000):
    """Compute decode time ranges for the bmf decoder (sampling 1 fps).

    - 0 < duration <= 24 s: no segmentation, decode everything (~0-24 frames)
    - 24 s < duration <= 600 s: 6 segments of 4 s each, 6x4 = 24 frames
    - duration > 600 s: 8 segments of 4 s each, 8x4 = 32 frames,
      decoding capped at max_dur (1000 s)

    Returns {} when no segmentation is needed, otherwise
    {'durations': [start0, end0, start1, end1, ...]} in seconds.
    """

    def _even_ticks(total_dur, seg_num):
        # Spread `seg_num` windows of `seg_dur` seconds evenly over
        # `total_dur` seconds; shared by the medium/long branches (the
        # original duplicated this loop in both).
        ticks = []
        seg_intev = (total_dur - seg_num * seg_dur) / (seg_num - 1)
        for s_i in range(seg_num):
            seg_init = s_i * (seg_dur + seg_intev)
            ticks.extend([round(seg_init, 3), round(seg_init + seg_dur, 3)])
        return ticks

    duration = get_duration(video_path)
    if duration < lv1_dur_thres:
        return dict()
    if duration <= 600:  # medium duration
        seg_num = 6
        seg_intev = (duration - seg_num * seg_dur) / (seg_num - 1)
        if seg_intev < 0.5:
            # Segments would nearly touch; just decode the whole clip.
            duration_ticks = [0, duration]
        else:
            duration_ticks = _even_ticks(duration, seg_num)
    else:  # long duration
        duration_ticks = _even_ticks(min(duration, max_dur), 8)
    return {'durations': duration_ticks}


if __name__ == '__main__':
    input_path = 'files/VD_0290_00405.png'
    out_path = 'result/20f72ebc978c4b06830e23adee6b6ff7.json'

    # Check that the sample input from the release tarball is present.
    if not os.path.exists(input_path):
        print(
            "please download input first, use 'wget https://github.com/BabitMF/bmf/releases/download/files/files.tar.gz && tar zxvf files.tar.gz' "
        )
        sys.exit(1)  # fix: was exit(0) — a missing input is a failure, not success

    # Check that the ONNX model used by the vqa module is present.
    model_path = "models/vqa_4kpgc_1.onnx"
    if not os.path.exists(model_path):
        print(
            "please download model first, use 'wget https://github.com/BabitMF/bmf/releases/download/files/models.tar.gz && tar zxvf models.tar.gz' "
        )
        sys.exit(1)  # fix: was exit(0)

    # Fix: make sure the output directory exists before the module tries to
    # write the result JSON into it.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    option = dict()
    option['output_path'] = out_path
    option['width'], option['height'] = get_width_and_height(input_path)

    # Decode at 1 fps; for long inputs restrict decoding to sampled segments.
    duration_segs = segment_decode_ticks(input_path)
    decode_params = {'input_path': input_path,
                     'video_params': {'extract_frames': {'fps': 1}}}
    decode_params.update(duration_segs)

    # Build and run the graph: decode -> vqa module -> upload.
    streams = bmf.graph().decode(decode_params)
    # NOTE(review): this resolves to the parent of the demo directory while
    # vqa_module.py sits next to this script — confirm the module path.
    py_module_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    py_entry = 'vqa_module.BMFVQA_4kpgc'
    video_stream = streams['video'].module('vqa_4kpgc_module', option,
                                           py_module_path, py_entry)
    video_stream.upload().run()













120 changes: 120 additions & 0 deletions bmf/demo/video_quality_assessment/vqa_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-


from module_utils.template import SyncModule
from module_utils.logger import get_logger
import os
import time
import json
import os.path as osp
import numpy as np
import onnxruntime as ort

#设置torch占用的内核数
os.environ["OMP_NUM_THREADS"] = "8"
LOGGER = get_logger()



def random_crop(d_img, crop_size):
    """Return a random spatial crop from a batch of NCHW images.

    d_img: array of shape (b, c, h, w) with h >= crop_size, w >= crop_size.
    crop_size: side length of the square crop.
    """
    b, c, h, w = d_img.shape
    # Fix: +1 so the upper bound is inclusive of the last valid top-left
    # position. np.random.randint's high is exclusive, so the original
    # randint(0, h - crop_size) could never pick the bottom/right-most crop
    # and raised ValueError when h == crop_size (randint(0, 0)).
    top = np.random.randint(0, h - crop_size + 1)
    left = np.random.randint(0, w - crop_size + 1)
    crop_img = d_img[:, :, top:top + crop_size, left:left + crop_size]
    return crop_img

def crop_for_video(img, crop_size, num_crop):
    """Stack `num_crop` independent random crops of `img` (NCHW) along the
    batch axis, producing a (num_crop * b, c, crop_size, crop_size) array."""
    crops = [random_crop(img, crop_size) for _ in range(num_crop)]
    return np.concatenate(crops, axis=0)




class VQA_4kpgc:
    """ONNX-based video quality assessment model for 4K PGC content.

    Accumulates one score per processed frame and writes the averaged
    result to `output_path` as JSON in clean().
    """

    def __init__(self, output_path, model_version=1, width=224, height=224, num_crop=5, caiyangjiange=100):
        """
        output_path: path of the JSON result file written by clean()
        model_version: version tag stored in the result JSON
        width/height: model input crop resolution
        num_crop: number of random crops averaged per frame
        caiyangjiange: frame sampling interval (currently unused — see the
            commented branch in process(); kept for interface compatibility)
        """
        self._frm_idx = 0        # frames scored so far
        self._frm_scores = []    # per-frame mapped scores
        self._output_path = output_path
        self._model_version = model_version

        self.num_crop = num_crop
        self.resize_reso = [width, height]
        self.caiyangjiange = caiyangjiange

        model_dir = osp.join(osp.abspath(osp.dirname(__file__)), 'models')
        vqa_4kpgc_model_path = osp.realpath(osp.join(model_dir, 'vqa_4kpgc_1.onnx'))
        self.ort_session = ort.InferenceSession(vqa_4kpgc_model_path)
        self.input_node = self.ort_session.get_inputs()[0]

        # Fix: the original message said "AdvColor", copy-pasted from
        # another model's module.
        LOGGER.info("create VQA_4kpgc model [CPU]")

    def preprocess(self, frame):
        """Normalize an HWC uint8 RGB frame to [-1, 1], reorder to NCHW and
        stack `num_crop` random crops along the batch axis."""
        frame = (frame.astype(np.float32) / 255.0 - np.array([0.5, 0.5, 0.5], dtype='float32')) / \
            (np.array([0.5, 0.5, 0.5], dtype='float32'))
        frame = np.transpose(frame, (2, 0, 1))
        frame = np.expand_dims(frame, 0)
        frame = crop_for_video(frame, self.resize_reso[0], self.num_crop)
        return frame

    @staticmethod
    def score_pred_mapping(preds):
        """Linearly map the raw network output (roughly [0, 1]) onto the
        reported score range [0.111..., 9.8]."""
        # Fix: renamed locals — the originals shadowed the builtins
        # `max`/`min`.
        score_max = 9.8
        score_min = 0.111111111111111
        pred_score = preds * (score_max - score_min) + score_min
        return pred_score

    def process(self, frames):
        """Score the first frame of `frames`; returns that frame unchanged."""
        self._frm_idx += 1
        # Either compute every `caiyangjiange`-th frame here, OR rely on
        # segment_decode_ticks() doing sampling at decode time (the current
        # approach — hence the branch below stays disabled).
        # if (self._frm_idx - 1) % self.caiyangjiange == 0:
        frames = [frame if frame.flags["C_CONTIGUOUS"] else np.ascontiguousarray(frame) for frame in frames]
        frame = self.preprocess(frames[0])
        if not frame.flags['C_CONTIGUOUS']:
            frame = np.ascontiguousarray(frame, dtype=np.float32)

        t1 = time.time()
        # Mean over all crops, then map to the reported score range.
        raw_score = self.ort_session.run(None, {self.input_node.name: frame})[0].mean()
        score = self.score_pred_mapping(raw_score)
        self._frm_scores.append(score)
        t2 = time.time()
        LOGGER.info(f'[vqa_4kpgc] inference time: {(t2 - t1)*1000:0.1f} ms')

        return frames[0]

    def clean(self):
        """Average the per-frame scores and write the result JSON."""
        nr_score = round(np.mean(self._frm_scores), 2)
        results = {'vqa_4kpgc': nr_score, 'vqa_4kpgc_version': self._model_version, 'num_crop': self.num_crop, 'cau_frames_num': self._frm_idx}
        LOGGER.info(f'overall prediction {json.dumps(results)}')
        with open(self._output_path, 'w') as outfile:
            json.dump(results, outfile, indent=4, ensure_ascii=False)





class BMFVQA_4kpgc(SyncModule):
    """BMF module adapter that feeds decoded rgb24 frames into VQA_4kpgc."""

    def __init__(self, node=None, option=None):
        """
        node: bmf node handle, forwarded to SyncModule
        option: dict with 'output_path' and optional 'model_version'
        """
        # Fix: removed the unused `height`/`width` locals — they were read
        # from `option` but never used; the model input is fixed at 224x224
        # regardless of the stream resolution.
        output_path = option.get('output_path', 0)
        model_version = option.get('model_version', 'v1.0')
        self._nrp = VQA_4kpgc(output_path=output_path,
                              model_version=model_version,
                              width=224, height=224,
                              num_crop=5, caiyangjiange=100)
        SyncModule.__init__(self, node, nb_in=1, in_fmt='rgb24', out_fmt='rgb24')

    def core_process(self, frames):
        """Called by SyncModule per frame window; returns the scored frame."""
        return self._nrp.process(frames)

    def clean(self):
        """Flush the aggregated result JSON via the wrapped model."""
        self._nrp.clean()