Skip to content

Commit

Permalink
flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
ae-foster committed Oct 23, 2018
1 parent d138046 commit a9e3b4a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
3 changes: 2 additions & 1 deletion pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
"TorchDistribution",
"VonMises",
"VonMises3D",
"ZeroInflatedPoisson"
"ZeroInflatedPoisson",
"CensoredDistribution"
]

# Import all torch distributions from `pyro.distributions.torch_distribution`
Expand Down
8 changes: 3 additions & 5 deletions pyro/distributions/censored.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch.distributions import constraints
from torch.distributions.utils import _sum_rightmost

from pyro.distributions.torch_distribution import TorchDistribution

Expand All @@ -9,7 +8,8 @@ class CensoredDistribution(TorchDistribution):

def __init__(self, base_distribution, upper_lim=float('inf'), lower_lim=float('-inf'), validate_args=None):
# Log-prob only computed correctly for univariate base distribution
assert base_distribution.event_dim == 0 or base_distribution.event_dim == 1 and base_distribution.event_shape[0] == 1
assert base_distribution.event_dim == 0 or (
base_distribution.event_dim == 1 and base_distribution.event_shape[0] == 1)
self.base_dist = base_distribution
self.upper_lim = upper_lim
self.lower_lim = lower_lim
Expand All @@ -33,7 +33,6 @@ def rsample(self, sample_shape=torch.Size()):
x[x > self.upper_lim] = self.upper_lim
x[x < self.lower_lim] = self.lower_lim


def log_prob(self, value):
"""
Scores the sample by giving a probability density relative to a new base measure.
Expand All @@ -44,7 +43,7 @@ def log_prob(self, value):
as for discrete distributions. `log_prob(x)` in the interior represent regular
pdfs with respect to Lebesgue measure on R.
**Note**: `log_prob` scores from distributions with different censoring are not
**Note**: `log_prob` scores from distributions with different censoring are not
comparable.
"""
log_prob = self.base_dist.log_prob(value)
Expand All @@ -68,4 +67,3 @@ def cdf(self, value):
def icdf(self, value):
# Is this even possible?
raise NotImplemented

0 comments on commit a9e3b4a

Please sign in to comment.