Skip to content

Commit

Permalink
nll computation
Browse files Browse the repository at this point in the history
  • Loading branch information
XavierXiao authored May 11, 2021
1 parent 32cb67c commit 5517c9b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions test_inputcomplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def KL_div(mu,logvar,reduction = 'none'):

def store_NLL(x, recon, mu, logvar, z):
with torch.no_grad():
sigma = torch.exp(0.5*logvar)
b = x.size(0)
target = Variable(x.data.view(-1) * 255).long()
recon = recon.contiguous()
Expand All @@ -41,7 +42,7 @@ def store_NLL(x, recon, mu, logvar, z):
log_p_x_z = -torch.sum(cross_entropy.view(b ,-1), 1)

log_p_z = -torch.sum(z**2/2+np.log(2*np.pi)/2,1)
z_eps = z - mu
z_eps = (z - mu)/sigma
z_eps = z_eps.view(opt.repeat,-1)
log_q_z_x = -torch.sum(z_eps**2/2 + np.log(2*np.pi)/2 + logvar/2, 1)

Expand Down Expand Up @@ -245,4 +246,4 @@ def compute_NLL(weights):
break
difference_ood = np.asarray(difference_ood)

np.save('./array/complexity/difference_ood.npy', difference_ood)
np.save('./array/complexity/difference_ood.npy', difference_ood)

0 comments on commit 5517c9b

Please sign in to comment.