-
Notifications
You must be signed in to change notification settings - Fork 1
/
extract_ssim.py
232 lines (192 loc) · 8.64 KB
/
extract_ssim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from skimage.util.shape import view_as_blocks
from skimage import filters
import matplotlib.pyplot as plt
from joblib import dump,Parallel,delayed
from scipy.stats import gmean
import time
from scipy.ndimage import gaussian_filter
from utils.hdr_utils import hdr_yuv_read
from utils.csf_utils import csf_barten_frequency,csf_filter_block,blockwise_csf,windows_csf
import numpy as np
import glob
import pandas as pd
import os
from os.path import join
import scipy
import colour
import socket
import sys
import argparse
from datetime import datetime
import warnings
from ssim_features import structural_similarity_features as ssim_features
def global_exp(image, par):
    """Exponentiate each pixel's deviation from the global image mean.

    Args:
        image: 2-D array-like (rank is asserted).
        par: scale applied to the mean-centred values before ``np.exp``.

    Returns:
        ``np.exp(par * (image - mean(image)))`` with the same shape as ``image``.
    """
    assert np.ndim(image) == 2
    centred = image - np.average(image)
    return np.exp(par * centred)
def gen_gauss_window(lw, sigma):
    """Build a normalised, symmetric 1-D Gaussian kernel.

    Args:
        lw: half-width; the kernel has ``2 * int(lw) + 1`` taps centred on
            index ``int(lw)``.
        sigma: standard deviation, in taps (evaluated as ``np.float32``).

    Returns:
        A Python list whose entries are ``exp(-0.5 * i**2 / sigma**2)`` for
        offsets ``i`` from the centre, divided by their total so they sum to 1.
    """
    variance = np.float32(sigma)
    variance *= variance
    half = int(lw)
    taps = [0.0] * (2 * half + 1)
    taps[half] = 1.0
    # Running total of the unnormalised taps; accumulated in the same order as
    # the taps are generated so the normalisation is numerically stable/identical.
    norm = 1.0
    for offset in range(1, half + 1):
        value = np.exp(-0.5 * np.float32(offset * offset) / variance)
        taps[half - offset] = value
        taps[half + offset] = value
        norm += 2.0 * value
    return [tap / norm for tap in taps]
def local_exp(image, par, patch_size):
    """Exponentiate each pixel's deviation from a Gaussian-weighted local mean.

    The local mean is a separable blur with the kernel from
    ``gen_gauss_window(patch_size // 2, 7/6)``. Borders are zero-padded
    (``mode='constant'``), so the mean is pulled toward 0 near image edges.

    Args:
        image: 2-D array-like (rank is asserted); converted to float32.
        par: scale applied to ``image - local_mean`` before ``np.exp``.
        patch_size: kernel length is ``2 * (patch_size // 2) + 1``.

    Returns:
        float32 array ``np.exp(par * (image - local_mean))``.
    """
    assert np.ndim(image) == 2
    rows, cols = np.shape(image)
    kernel = gen_gauss_window(patch_size // 2, 7.0 / 6.0)
    pixels = np.array(image).astype('float32')
    local_mean = np.zeros((rows, cols), dtype=np.float32)
    # Separable blur: filter along axis 0 into local_mean, then along axis 1
    # using local_mean as both input and output.
    scipy.ndimage.correlate1d(pixels, kernel, 0, local_mean, mode='constant')
    scipy.ndimage.correlate1d(local_mean, kernel, 1, local_mean, mode='constant')
    return np.exp(par * (pixels - local_mean))
def m_exp(image, par, patch_size=31):
    """Locally normalised, sign-preserving exponential nonlinearity.

    Each pixel is first mapped linearly into roughly ``[-4, 4)`` using the
    min/max over a ``patch_size x patch_size`` neighbourhood. Then
    ``sign(x) * (exp(|x| ** par) - 1)`` is applied.

    Args:
        image: 2-D array.
        par: exponent applied to ``|x|`` inside the exponential.
        patch_size: side length of the min/max filter window.

    Returns:
        Float array with the same shape as ``image``.
    """
    window = (patch_size, patch_size)
    local_max = scipy.ndimage.maximum_filter(image, size=window)
    local_min = scipy.ndimage.minimum_filter(image, size=window)
    # 1e-3 guards against division by zero in flat neighbourhoods.
    scaled = -4 + (image - local_min) * 8 / (1e-3 + local_max - local_min)
    negative = scaled < 0
    out = np.exp(np.abs(scaled) ** par) - 1
    out[negative] = -out[negative]
    return out
def global_m_exp(Y, delta):
    """Globally normalised, sign-preserving exponential nonlinearity.

    ``Y`` is mapped linearly into roughly ``[-4, 4)`` using its global min and
    max. Then ``sign(x) * (exp(|x| ** delta) - 1)`` is applied.

    Args:
        Y: array of any shape.
        delta: exponent applied to ``|x|`` inside the exponential.

    Returns:
        Float array with the same shape as ``Y``.
    """
    lo = np.amin(Y)
    hi = np.amax(Y)
    # 1e-3 guards against division by zero for constant input.
    scaled = -4 + (Y - lo) * 8 / (1e-3 + hi - lo)
    negative = scaled < 0
    out = np.exp(np.abs(scaled) ** delta) - 1
    out[negative] = -out[negative]
    return out
def logit(Y, par, patch_size=31):
    """Locally normalised logit-style nonlinearity ``log((1 + t) / (1 - t))``.

    ``Y`` is mapped linearly into ``[-0.99, 0.99)`` using the min/max over a
    ``patch_size x patch_size`` neighbourhood, giving ``Y_scaled``. Then
    ``t = Y_scaled ** par``. For even ``par`` the power discards the sign of
    ``Y_scaled``, so the sign is restored afterwards. This makes the output
    negative wherever ``Y_scaled`` is negative, as for odd ``par``.

    Args:
        Y: 2-D array.
        par: power applied to the scaled values. NOTE(review): for a
            non-integer ``par``, negative ``Y_scaled`` raised to ``par`` yields
            NaN (NumPy float power semantics).
        patch_size: side length of the min/max filter window (default 31,
            the previously hard-coded value).

    Returns:
        Float array with the same shape as ``Y``.
    """
    window = (patch_size, patch_size)
    maxY = scipy.ndimage.maximum_filter(Y, size=window)
    minY = scipy.ndimage.minimum_filter(Y, size=window)
    delta = par
    # 1e-3 guards against division by zero in flat neighbourhoods.
    Y_scaled = -0.99 + 1.98 * (Y - minY) / (1e-3 + maxY - minY)
    Y_powered = Y_scaled ** delta
    Y_transform = np.log((1 + Y_powered) / (1 - Y_powered))
    if delta % 2 == 0:
        # The sign to restore is that of the *scaled* values; the raw input Y
        # is not centred on zero (it was previously, wrongly, tested here).
        negative = Y_scaled < 0
        Y_transform[negative] = -Y_transform[negative]
    return Y_transform
def global_logit(Y, par):
    """Globally normalised logit-style nonlinearity ``log((1 + t) / (1 - t))``.

    ``Y`` is mapped linearly into ``[-0.99, 0.99)`` using its global min and
    max, giving ``Y_scaled``. Then ``t = Y_scaled ** par``. For even ``par``
    the power discards the sign of ``Y_scaled``, so the sign is restored
    afterwards. This makes the output negative wherever ``Y_scaled`` is
    negative, as for odd ``par``.

    Args:
        Y: array of any shape.
        par: power applied to the scaled values. NOTE(review): for a
            non-integer ``par``, negative ``Y_scaled`` raised to ``par`` yields
            NaN (NumPy float power semantics).

    Returns:
        Float array with the same shape as ``Y``.
    """
    delta = par
    # 1e-3 guards against division by zero for constant input.
    Y_scaled = -0.99 + 1.98 * (Y - np.amin(Y)) / (1e-3 + np.amax(Y) - np.amin(Y))
    Y_powered = Y_scaled ** delta
    Y_transform = np.log((1 + Y_powered) / (1 - Y_powered))
    if delta % 2 == 0:
        # Restore the sign of the *scaled* values; testing the raw input Y
        # (as before) almost never fires because Y is not centred on zero.
        negative = Y_scaled < 0
        Y_transform[negative] = -Y_transform[negative]
    return Y_transform
def ssim_refall_wrapper(ind):
    """Score row ``ind`` of the video-info table.

    Looks up the distorted (``files``) and reference (``ref_names``) file
    names, prints them, and passes their paths under ``vid_pth`` to
    ``ssim_video_wrapper``. All three are module globals; this is the function
    dispatched per row index by the joblib ``Parallel`` call.
    """
    dis_name = files[ind]
    ref_name = ref_names[ind]
    print(dis_name, ref_name)
    ref_path, dis_path = (os.path.join(vid_pth, name) for name in (ref_name, dis_name))
    ssim_video_wrapper(ref_path, dis_path, ind)
def ssim_video_wrapper(ref_f, dis_f, dis_index):
    """Choose the frame range for one video pair and run ``ssim_image_wrapper``.

    Prints timestamps before and after. Returns immediately, without
    computing anything, when the reference and distorted paths are identical.

    Frame size is fixed at 2160x3840. With ``--frame_range all`` the frame
    count is derived from file size at ``h * w * 3`` bytes per frame.
    NOTE(review): that presumably matches 16-bit-per-sample 4:2:0 YUV as read
    by ``hdr_yuv_read``; confirm. If the two files differ in length, a warning
    is issued and the shorter count is used. With ``--frame_range file``, the
    range comes from the module-level ``start_list`` / ``end_list``.

    Reads the module globals ``args``, ``start_list`` and ``end_list``.
    """
    print("Current Time =", datetime.now().strftime("%H:%M:%S"))
    print(os.path.basename(dis_f))
    if ref_f == dis_f:
        print('Videos are the same')
        return
    # Fixed 4K geometry; per-video sizes (hs/ws) are not used.
    h, w = 2160, 3840
    if args.frame_range != 'all':
        start, end = start_list[dis_index], end_list[dis_index]
    else:
        start = 0
        bytes_per_frame = h * w * 3
        dis_num_frames = os.path.getsize(dis_f) // bytes_per_frame
        ref_num_frames = os.path.getsize(ref_f) // bytes_per_frame
        if dis_num_frames != ref_num_frames:
            warnings.warn('The number of frames in the reference and distorted videos are not the same. The smaller of the two will be used.')
        end = min(dis_num_frames, ref_num_frames)
    ssim_image_wrapper(ref_f, dis_f, start, end, h, w,
                       space=args.space, channel=args.channel, ind=dis_index)
    print("Current Time =", datetime.now().strftime("%H:%M:%S"))
def _ycbcr_to_hdr_lab(yuv):
    """Convert a normalised BT.2020 YCbCr frame to hdr-CIELab (Fairchild 2011).

    Args:
        yuv: frame with the three planes stacked on the last axis and code
            values already divided by 1023 (as done by the caller).

    Returns:
        The ``colour.XYZ_to_hdr_CIELab`` result, channels on the last axis.
    """
    # K holds the BT.2020 luma coefficients (Kr, Kb).
    rgb = colour.YCbCr_to_RGB(yuv, K=[0.2627, 0.0593])
    # Decode with the PQ EOTF; the /10000 presumably normalises PQ's
    # 10000 cd/m^2 peak before the hdr-CIELab conversion.
    xyz = colour.RGB_to_XYZ(rgb, [0.3127, 0.3290], [0.3127, 0.3290],
                            colour.models.RGB_COLOURSPACE_BT2020.RGB_to_XYZ_matrix,
                            chromatic_adaptation_transform='CAT02',
                            cctf_decoding=colour.models.eotf_PQ_BT2100) / 10000
    return colour.XYZ_to_hdr_CIELab(xyz, illuminant=[0.3127, 0.329], Y_s=0.2,
                                    Y_abs=100, method='Fairchild 2011')


def ssim_image_wrapper(ref_f, dis_f, start, end, h, w, space, channel, ind):
    """Compute frame-averaged SSIM features for one reference/distorted pair.

    Reads every ``int(fps[ind])``-th frame in ``[start, end)`` from both raw
    YUV files and optionally converts each frame's colour space. It then runs
    ``ssim_features`` on the selected channel and writes the per-feature mean
    over the sampled frames as a one-row CSV to
    ``<out_pth_ssim>/<distorted basename>.csv``.

    Args:
        ref_f, dis_f: paths to the reference / distorted raw YUV files.
        start, end: frame range to sample (``end`` exclusive).
        h, w: frame height and width passed to ``hdr_yuv_read``.
        space: ``'ycbcr'`` (divide planes by 1023), ``'lab'`` (hdr-CIELab), or
            anything else to use the planes from ``hdr_yuv_read`` unchanged.
        channel: index (0, 1 or 2) of the plane/channel fed to ``ssim_features``.
        ind: row index into the module-level ``fps`` series.

    Reads module globals ``fps`` and ``out_pth_ssim``. Reading stops at the
    first frame that fails to load. If no frame is processed, a warning is
    issued and no CSV is written (previously a NaN row was written).
    """
    # Raw YUV is binary data: open in binary mode, and let ``with`` close both
    # handles even if feature extraction raises.
    with open(ref_f, 'rb') as ref_file_object, open(dis_f, 'rb') as dis_file_object:
        # Step of int(fps) frames, i.e. about one frame per second assuming the
        # 'fps' column is the frame rate.
        framelist = list(range(start, end, int(fps[ind])))
        print(f'Extracting frames from {start} to {end}')
        dis_name = os.path.splitext(os.path.basename(dis_f))[0]
        output_csv_ssim = os.path.join(out_pth_ssim, dis_name + '.csv')
        ssim_feats = []
        for framenum in framelist:
            try:
                ref_multichannel = hdr_yuv_read(ref_file_object, framenum, h, w)
                dis_multichannel = hdr_yuv_read(dis_file_object, framenum, h, w)
            except Exception as e:
                # Best effort: stop at the first unreadable frame (e.g. a
                # truncated file) and keep the features gathered so far.
                print(e)
                break
            if space == 'ycbcr':
                # Treat samples as 10-bit code values and scale by 1023.
                ref_multichannel = [i.astype(np.float64) / 1023 for i in ref_multichannel]
                dis_multichannel = [i.astype(np.float64) / 1023 for i in dis_multichannel]
            elif space == 'lab':
                # Stack planes on a new last axis, normalise, convert, then move
                # channels back to the first axis for the indexing below.
                ref_multichannel = _ycbcr_to_hdr_lab(
                    np.stack(ref_multichannel, axis=2).astype(np.float64) / 1023
                ).transpose(2, 0, 1)
                dis_multichannel = _ycbcr_to_hdr_lab(
                    np.stack(dis_multichannel, axis=2).astype(np.float64) / 1023
                ).transpose(2, 0, 1)
            ssim_feat = ssim_features(ref_multichannel[channel], dis_multichannel[channel])
            ssim_feats.append(ssim_feat)
    if not ssim_feats:
        # np.mean over an empty list would silently yield a NaN row.
        warnings.warn(f'No frames could be processed for {dis_f}; {output_csv_ssim} was not written.')
        return
    # Average each feature over the sampled frames and save as a one-row CSV.
    ssim_feats = np.mean(np.array(ssim_feats), axis=0)
    df = pd.DataFrame(ssim_feats.reshape(1, -1))
    df.to_csv(output_csv_ssim, index=False)
if __name__ == "__main__":
    # Command-line entry point; guarded so importing this module (e.g. for the
    # nonlinearity helpers) does not parse sys.argv or start workers.
    # The worker functions read args, files, ref_names, fps, vid_pth,
    # out_pth_ssim, start_list and end_list as module globals, so they are
    # bound at module scope here rather than inside a main() function.
    # NOTE(review): this relies on worker processes inheriting the parent's
    # globals (fork start method); spawn/forkserver workers would not see them.
    parser = argparse.ArgumentParser()
    parser.add_argument('vid_pth', type=str,
                        help='directory containing both the reference and the distorted videos')
    parser.add_argument('feature_path', type=str,
                        help="output directory; features are written to its 'hdrssimnew' subdirectory")
    parser.add_argument('csv_file_vidinfo', type=str,
                        help="csv file with 'encoded_yuv', 'ref_yuv' and 'fps' columns "
                             "(plus 'start' and 'end' when --frame_range is 'file')")
    parser.add_argument("--space", choices=['ycbcr', 'lab'],
                        help="choose which color space. Support 'ycbcr' and 'lab'.")
    parser.add_argument("--channel", type=int, choices=[0, 1, 2],
                        help="indicate which channel to process. Please provide 0, 1, or 2")
    parser.add_argument("--njobs", help="Number of videos processed at the same time.", type=int, default=1)
    parser.add_argument("--frame_range", type=str, default='all', choices=['all', 'file'],
                        help="frame range to process. 'all' or 'file'. if 'all', the whole video is used to estimate the quality. if 'file', the video uses the 'start' and 'end' columns in the csv file to estimate the quality.")
    args = parser.parse_args()
    print(args.space)

    csv_file_vidinfo = args.csv_file_vidinfo
    vid_pth = args.vid_pth
    feature_path = args.feature_path
    njobs = args.njobs

    df_vidinfo = pd.read_csv(csv_file_vidinfo)
    files = df_vidinfo["encoded_yuv"]
    fps = df_vidinfo["fps"]
    if args.frame_range == 'file':
        try:
            start_list = df_vidinfo["start"]
            end_list = df_vidinfo["end"]
        except KeyError as err:
            # Only a missing column is expected here; chain the original error.
            raise ValueError("Please provide 'start' and 'end' columns in the csv file when --frame_range is 'file'.") from err
    ref_names = df_vidinfo["ref_yuv"]

    out_pth_ssim = join(feature_path, 'hdrssimnew')
    os.makedirs(out_pth_ssim, exist_ok=True)

    # One task per CSV row; the row index selects all per-video metadata.
    Parallel(n_jobs=njobs, verbose=1, backend="multiprocessing")(
        delayed(ssim_refall_wrapper)(i) for i in range(len(files))
    )