Skip to content

Commit

Permalink
Merge pull request #151 from rs-station/dw_cpu
Browse files Browse the repository at this point in the history
Dw cpu resolves #150
  • Loading branch information
kmdalton authored Jan 18, 2024
2 parents ecdf8c0 + 486136a commit 39941c7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 8 additions & 1 deletion careless/models/priors/wilson.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,14 @@ def stddev(self):
return self.wilson_prior.stddev()

def log_prob(self, z):
z_parent = tf.gather(z, self.reflids, axis=-1)
mask = self.reflids >= 0
sanitized_reflids = tf.where(mask, self.reflids, 0)
z_parent = tf.where(
mask[None,:],
tf.gather(z, sanitized_reflids, axis=-1),
0.,
)

loc = tf.where(
self.absent,
0.,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def getVersionNumber():
install_requires=[
"reciprocalspaceship>=0.9.16",
"tqdm",
"tensorflow>=2.8",
"tensorflow<2.16",
"tensorflow-probability",
"matplotlib",
"seaborn",
Expand Down

0 comments on commit 39941c7

Please sign in to comment.