diff --git a/careless/models/priors/wilson.py b/careless/models/priors/wilson.py index 18af7d3..dfa93a0 100644 --- a/careless/models/priors/wilson.py +++ b/careless/models/priors/wilson.py @@ -136,7 +136,14 @@ def stddev(self): return self.wilson_prior.stddev() def log_prob(self, z): - z_parent = tf.gather(z, self.reflids, axis=-1) + mask = self.reflids >= 0 + sanitized_reflids = tf.where(mask, self.reflids, 0) + z_parent = tf.where( + mask[None,:], + tf.gather(z, sanitized_reflids, axis=-1), + 0., + ) + loc = tf.where( self.absent, 0., diff --git a/setup.py b/setup.py index 579ddbc..3da6cb2 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def getVersionNumber(): install_requires=[ "reciprocalspaceship>=0.9.16", "tqdm", - "tensorflow>=2.8", + "tensorflow<2.16", "tensorflow-probability", "matplotlib", "seaborn",