diff --git a/jarvis/core/lattice.py b/jarvis/core/lattice.py index a7d21802..03befb80 100644 --- a/jarvis/core/lattice.py +++ b/jarvis/core/lattice.py @@ -63,7 +63,7 @@ def lat_lengths(self): """Return lattice vectors' lengths.""" return [ round(i, 6) - for i in (np.sqrt(np.sum(self._lat ** 2, axis=1)).tolist()) + for i in (np.sqrt(np.sum(self._lat**2, axis=1)).tolist()) ] # return [round(np.linalg.norm(v), 6) for v in self._lat] @@ -239,7 +239,9 @@ def reciprocal_lattice_crystallographic(self): """Return reciprocal Lattice without 2 * pi.""" return Lattice(self.reciprocal_lattice().matrix / (2 * np.pi)) - def get_points_in_sphere(self, frac_points, center, r): + def get_points_in_sphere( + self, frac_points, center, r, distance_vector=True + ): """ Find all points within a sphere from the point. @@ -283,11 +285,21 @@ def get_points_in_sphere(self, frac_points, center, r): tmp_img = cart_images[None, :, :, :, :] coords = cart_coords[:, None, None, None, :] + tmp_img coords -= center[None, None, None, None, :] + dist_vect = coords coords **= 2 d_2 = np.sum(coords, axis=4) # Determine which points are within `r` of `center` - within_r = np.where(d_2 <= r ** 2) + within_r = np.where(d_2 <= r**2) + if distance_vector: + return ( + shifted_coords[within_r], + np.sqrt(d_2[within_r]), + indices[within_r[0]], + images[within_r[1:]], + dist_vect[within_r], + ) + return ( shifted_coords[within_r], np.sqrt(d_2[within_r]), @@ -316,9 +328,7 @@ def find_all_matches(self, other_lattice, ltol=1e-5, atol=1): ] c_a, c_b, c_c = (cart[i] for i in inds) f_a, f_b, f_c = (frac[i] for i in inds) - l_a, l_b, l_c = ( - np.sum(c ** 2, axis=-1) ** 0.5 for c in (c_a, c_b, c_c) - ) + l_a, l_b, l_c = (np.sum(c**2, axis=-1) ** 0.5 for c in (c_a, c_b, c_c)) def get_angles(v1, v2, l1, l2): x = np.inner(v1, v2) / l1[:, None] / l2