
Transformed distribution under Planar flow -> not integrating to 1 #1

Open
vr308 opened this issue Jun 28, 2020 · 1 comment

vr308 commented Jun 28, 2020

Hi there!

Thanks for your excellent tutorial on NFs and the code. While playing around with it, one thing I wanted to do was generate transformed pdfs and check whether they integrate to 1. Using your planar flow class and its forward method, I wrote a code snippet to test this with np.trapz().

```python
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import pyro.distributions as dist

class Planar(nn.Module):
    def __init__(self, size=1, init_sigma=0.01):
        super().__init__()
        self.u = nn.Parameter(torch.randn(1, size).normal_(0, init_sigma))
        self.w = nn.Parameter(torch.randn(1, size).normal_(0, init_sigma))
        self.b = nn.Parameter(torch.zeros(1))

    @property
    def normalized_u(self):
        """
        Needed for invertibility condition.

        See Appendix A.1
        Rezende et al. Variational Inference with Normalizing Flows
        https://arxiv.org/pdf/1505.05770.pdf
        """

        # m(x) = -1 + softplus(x)
        def m(x):
            return -1 + torch.log(1 + torch.exp(x))

        wtu = torch.matmul(self.w, self.u.t())
        w_div_w2 = self.w / torch.norm(self.w)
        return self.u + (m(wtu) - wtu) * w_div_w2

    def psi(self, z):
        """
        ψ(z) = h′(w^T z + b) w

        See eq(11)
        Rezende et al. Variational Inference with Normalizing Flows
        https://arxiv.org/pdf/1505.05770.pdf
        """
        return self.h_prime(z @ self.w.t() + self.b) @ self.w

    def h(self, x):
        return torch.tanh(x)

    def h_prime(self, z):
        return 1 - torch.tanh(z) ** 2

    def forward(self, z):
        if isinstance(z, tuple):
            z, accumulating_ldj = z
        else:
            z, accumulating_ldj = z, 0
        psi = self.psi(z)

        u = self.normalized_u

        # determinant of Jacobian: 1 + ψ(z)^T u
        det = (1 + psi @ u.t())

        # log |det Jac|
        ldj = torch.log(torch.abs(det) + 1e-6)

        wzb = z @ self.w.t() + self.b

        # f(z) = z + u h(w^T z + b)
        fz = z + (u * self.h(wzb))

        return fz, ldj + accumulating_ldj
```
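The `normalized_u` property follows Appendix A.1 of Rezende et al. The reparameterisation given there, û = u + [m(wᵀu) − wᵀu] w / ‖w‖² with m(x) = −1 + log(1 + eˣ), guarantees wᵀû = m(wᵀu) > −1, the sufficient invertibility condition for tanh planar flows. The posted `normalized_u` divides by `torch.norm(self.w)`, i.e. ‖w‖, even though the variable is named `w_div_w2`. A minimal NumPy sketch (not the tutorial's code; the parameter values are hypothetical, chosen so that wᵀu = −2) compares the two denominators:

```python
import numpy as np


def m(x):
    # m(x) = -1 + softplus(x); np.logaddexp(0, x) computes log(1 + e^x) stably
    return -1.0 + np.logaddexp(0.0, x)


def normalized_u(u, w, squared_norm=True):
    """u_hat = u + [m(w.u) - w.u] * w / ||w||^2, as in Appendix A.1.

    squared_norm=False divides by ||w|| instead, for comparison only.
    """
    wtu = w @ u
    denom = w @ w if squared_norm else np.linalg.norm(w)
    return u + (m(wtu) - wtu) * w / denom


# Hypothetical parameters with w.u = -2, violating the condition w.u >= -1
w = np.array([0.5, 0.0])
u = np.array([-4.0, 0.0])

u_hat = normalized_u(u, w)
u_hat_norm = normalized_u(u, w, squared_norm=False)

print(w @ u)           # -2.0
print(w @ u_hat)       # m(-2) ≈ -0.873, so w.u_hat >= -1 holds
print(w @ u_hat_norm)  # ≈ -1.437, so the condition fails
```

With ‖w‖² in the denominator, wᵀû equals m(wᵀu) exactly, whatever the scale of w; with ‖w‖ the identity only holds when ‖w‖ = 1.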
    

Perhaps I am missing something in the way I generate the pdf? Is there anything apart from the Jacobian adjustment I need to worry about?

```python
if __name__ == '__main__':

    # define a 100 x 100 meshgrid over [-5, 5] x [-5, 5]
    x1 = torch.tensor(data=np.linspace(-5,5,100))
    x2 = torch.tensor(data=np.linspace(-5,5,100))

    x1_s, x2_s = torch.meshgrid(x1, x2, indexing='ij')
    x_field = torch.stack([x1_s, x2_s], dim=-1).float()  # shape (100, 100, 2)

    # unit Gaussian base dist.
    base_dist = dist.MultivariateNormal(loc=torch.zeros(2), covariance_matrix=torch.eye(2))

    # Planar flow
    pf = Planar(size=2)

    xk, ldj = pf.forward(x_field)  # ldj has shape (100, 100, 1)

    # Generating pdf and checking if it integrates to 1
    planar_pdf = torch.exp(base_dist.log_prob(x_field) - ldj.reshape(100,100))
    print(np.trapz(np.trapz(planar_pdf.detach(), torch.linspace(-7,7,100), axis=0), torch.linspace(-7,7,100)))
```
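On the "anything apart from the Jacobian adjustment" part: `planar_pdf` holds q₀(z) / |det J(z)|, which is the density of x = f(z) evaluated at the transformed point f(z), but laid out on the z-grid. Integrating it over z computes ∫ q₀(z) / |det J(z)| dz, which is 1 only when |det J| ≈ 1; the change-of-variables identity that holds exactly is ∫ q_x(f(z)) |det J(z)| dz = ∫ q₀(z) dz = 1. With the posted initialisation (`init_sigma=0.01`) and `normalized_u`, |det J| stays close to 1 across the grid, so the two nearly coincide. A minimal NumPy sketch, using hypothetical, deliberately non-trivial parameters (u = (1, 0), w = (2, 0), b = 0, not from the tutorial), shows the difference:

```python
import numpy as np

# np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists
_trapz = getattr(np, "trapezoid", None) or np.trapz


def trapz2d(f, x1, x2):
    # integrate over axis 0 (x1) first, then over the remaining axis (x2)
    return _trapz(_trapz(f, x1, axis=0), x2)


def planar(z, u, w, b):
    """Planar map f(z) = z + u * tanh(w.z + b) and log|det J| = log|1 + h'(w.z + b) * w.u|."""
    a = z @ w + b                          # shape z.shape[:-1]
    fz = z + u * np.tanh(a)[..., None]
    det = 1.0 + (1.0 - np.tanh(a) ** 2) * (w @ u)
    return fz, np.log(np.abs(det))


def log_std_normal(z):
    # log density of a standard 2-D Gaussian: -||z||^2 / 2 - log(2*pi)
    return -0.5 * np.sum(z ** 2, axis=-1) - np.log(2 * np.pi)


# 100 x 100 grid over [-5, 5]^2 in z-space, as in the question
z1 = np.linspace(-5, 5, 100)
z2 = np.linspace(-5, 5, 100)
Z = np.stack(np.meshgrid(z1, z2, indexing="ij"), axis=-1)  # shape (100, 100, 2)

# Hypothetical parameters (w.u = 2 >= -1, so f is invertible)
u, w, b = np.array([1.0, 0.0]), np.array([2.0, 0.0]), 0.0
_, ldj = planar(Z, u, w, b)

# Density of x = f(z) at the points f(z), laid out on the z-grid
pdf_at_fz = np.exp(log_std_normal(Z) - ldj)

naive = trapz2d(pdf_at_fz, z1, z2)                         # integral of q0(z) / |det J(z)|
change_of_vars = trapz2d(pdf_at_fz * np.exp(ldj), z1, z2)  # integral of q_x(f(z)) |det J(z)|
print(naive, change_of_vars)  # naive is well below 1, change_of_vars is close to 1
```

Evaluating the density on a fixed x-grid instead would require inverting f, which planar flows generally do not offer in closed form.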

@ritchie46
Owner

Hi!

You define the field over the range -5..5:

    x1 = torch.tensor(data=np.linspace(-5,5,100))
    x2 = torch.tensor(data=np.linspace(-5,5,100))

And then you integrate over the range -7..7:

    print(np.trapz(np.trapz(planar_pdf.detach(), torch.linspace(-7,7,100), axis=0), torch.linspace(-7,7,100)))

If you make the ranges equal, the result will come out approximately 1.
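The size of the error can be checked directly. A minimal NumPy sketch, assuming a plain standard 2-D Gaussian stands in for `planar_pdf`: `np.trapz` only sees the 100 sample values and the coordinates passed to it, so supplying a -7..7 axis for values sampled on -5..5 stretches every grid cell by 14/10 along each axis and inflates the 2-D result by (14/10)² = 1.96.

```python
import numpy as np

# np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists
_trapz = getattr(np, "trapezoid", None) or np.trapz

# Standard 2-D Gaussian sampled on the same 100 x 100 grid over [-5, 5]^2 as in the question
x = np.linspace(-5, 5, 100)
X1, X2 = np.meshgrid(x, x, indexing="ij")
pdf = np.exp(-0.5 * (X1 ** 2 + X2 ** 2)) / (2 * np.pi)

# Matching coordinates: trapz uses the spacing the pdf was actually sampled at
matched = _trapz(_trapz(pdf, x, axis=0), x)

# Mismatched coordinates: same 100 x 100 values, but trapz assumes they span [-7, 7]
x_wrong = np.linspace(-7, 7, 100)
mismatched = _trapz(_trapz(pdf, x_wrong, axis=0), x_wrong)

print(matched, mismatched, mismatched / matched)  # ratio ≈ (14 / 10) ** 2 = 1.96
```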
