diff --git a/test_inputcomplexity.py b/test_inputcomplexity.py index b90d86b..ff93e5f 100644 --- a/test_inputcomplexity.py +++ b/test_inputcomplexity.py @@ -33,6 +33,7 @@ def KL_div(mu,logvar,reduction = 'none'): def store_NLL(x, recon, mu, logvar, z): with torch.no_grad(): + sigma = torch.exp(0.5*logvar) b = x.size(0) target = Variable(x.data.view(-1) * 255).long() recon = recon.contiguous() @@ -41,7 +42,7 @@ def store_NLL(x, recon, mu, logvar, z): log_p_x_z = -torch.sum(cross_entropy.view(b ,-1), 1) log_p_z = -torch.sum(z**2/2+np.log(2*np.pi)/2,1) - z_eps = z - mu + z_eps = (z - mu)/sigma z_eps = z_eps.view(opt.repeat,-1) log_q_z_x = -torch.sum(z_eps**2/2 + np.log(2*np.pi)/2 + logvar/2, 1) @@ -245,4 +246,4 @@ def compute_NLL(weights): break difference_ood = np.asarray(difference_ood) - np.save('./array/complexity/difference_ood.npy', difference_ood) \ No newline at end of file + np.save('./array/complexity/difference_ood.npy', difference_ood)