Skip to content

Commit

Permalink
Merge pull request #40 from minaskar/dev
Browse files Browse the repository at this point in the history
Version 1.0.1
  • Loading branch information
minaskar authored Jan 29, 2024
2 parents 61d7e0f + 0a4d589 commit 84c157f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pocomc/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "1.0.0"
version = "1.0.1"
18 changes: 11 additions & 7 deletions pocomc/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ def __init__(self, n_dim, flow=None):
residual=True,)
else:
self.flow = flow
self.transform = self.flow().transform

@property
def transform(self):
"""
Transformation object.
"""
return self.flow().transform

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Expand Down Expand Up @@ -86,9 +92,8 @@ def inverse(self, u: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Transformed samples in the original space with the same shape as the latent space inputs.
"""
u = torch_double_to_float(u)
x = self.transform.inv(u)
logdetj = self.transform.log_abs_det_jacobian(x, None)
return x, -logdetj
x, logdetj = self.transform.inv.call_and_ladj(u)
return x, logdetj

def log_prob(self, x: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -118,9 +123,8 @@ def sample(self, size: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
samples, log_prob : ``tuple``
Samples as a ``torch.Tensor`` with shape ``(size, n_dimensions)`` and log probability values with shape ``(size, )``.
"""
u = torch.randn(size, self.n_dim)
x = self.transform.inv(u)
return x, self.flow().log_prob(x)
x, log_p = self.flow().rsample_and_log_prob((size,))
return x, log_p

def fit(self,
x,
Expand Down

0 comments on commit 84c157f

Please sign in to comment.