Skip to content

Commit

Permalink
Fixes minor basis msonable issues and test_cofe/test_expansion/test_p…
Browse files Browse the repository at this point in the history
…rune. All tests in test_cofe now pass.
  • Loading branch information
qchempku2017 committed Jan 7, 2025
1 parent 4416093 commit 42ea6fc
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 11 deletions.
23 changes: 17 additions & 6 deletions smol/cofe/space/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def __init__(self, site_space, basis_functions):
f"species {site_space} in the site space provided."
)

self._f_array = self._construct_function_array(basis_functions)
# Enforce float64 type to guarantee compatibility.
self._f_array = np.array(
self._construct_function_array(basis_functions), dtype=np.float64
)

@abstractmethod
def _construct_function_array(self, basis_functions):
Expand Down Expand Up @@ -211,8 +214,7 @@ def _construct_function_array(self, basis_functions):
[[function(sp) for sp in self.species] for function in nconst_functions]
)
# stack the constant basis function on there for proper normalization
# Enforce float64 type to guarantee compatibility.
return np.vstack((np.ones_like(func_arr[0]), func_arr)).astype(np.float64)
return np.vstack((np.ones_like(func_arr[0]), func_arr))

@property
def function_array(self):
Expand Down Expand Up @@ -251,6 +253,10 @@ def orthonormalize(self):
self._r_array = q_mat[:, 0] / np.sqrt(self.measure_vector) * r_mat.T
self._f_array = q_mat.T / q_mat[:, 0] # make first row constant = 1

# Enforce float64 to guarantee compatibility.
self._r_array = self._r_array.astype(np.float64)
self._f_array = self._f_array.astype(np.float64)

def rotate(self, angle, index1=0, index2=1):
"""Rotate basis functions about subspace spanned by two vectors.
Expand Down Expand Up @@ -354,11 +360,16 @@ def from_dict(cls, d):
site_space = SiteSpace.from_dict(d["site_space"])
site_basis = basis_factory(d["flavor"], site_space)
# restore arrays
site_basis._f_array = np.array(d["func_array"])
site_basis._r_array = np.array(d["orthonorm_array"])
site_basis._f_array = np.array(d["func_array"]).astype(np.float64)
site_basis._r_array = (
None
if d["orthonorm_array"] is None
else np.array(d["orthonorm_array"]).astype(np.float64)
)
rot_array = d.get("rot_array")
if rot_array is not None:
site_basis._rot_array = np.array(rot_array)
# Enforce float64 to ensure compatibility
site_basis._rot_array = np.array(rot_array).astype(np.float64)
return site_basis


Expand Down
3 changes: 2 additions & 1 deletion smol/cofe/space/orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def correlation_tensors(self):
def flat_correlation_tensors(self):
"""Get correlation_tensors flattened to 2D for fast cython."""
if self._flat_corr_tensors is None:
# Enforce float64 to ensure compatibility.
self._flat_corr_tensors = np.ascontiguousarray(
np.reshape(
self.correlation_tensors,
Expand All @@ -256,7 +257,7 @@ def flat_correlation_tensors(self):
),
order="C",
)
)
).astype(np.float64)
return self._flat_corr_tensors

@property
Expand Down
3 changes: 2 additions & 1 deletion tests/test_cofe/test_clusterspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,8 @@ def test_site_basis_rotation(cluster_subspace, rng):


def _encode_occu(occu, bits):
return np.array([bit.index(sp) for sp, bit in zip(occu, bits)])
# Enforce int32 to ensure compatibility.
return np.array([bit.index(sp) for sp, bit in zip(occu, bits)], dtype=np.int32)


def test_vs_CASM_pairs(single_structure):
Expand Down
12 changes: 9 additions & 3 deletions tests/test_cofe/test_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,15 @@ def test_prune(cluster_expansion):
for s, m in zip(cluster_expansion.structures, cluster_expansion.scmatrices)
]
)
preds = np.sum(
cluster_expansion.cluster_subspace.orbit_multiplicities * ints, axis=1
)

# 01/07/2024 Fengyu Xie: using the original cluster_expansion.cluster_subspace is
# not safe, as pruning may occasionally prune out an entire orbit, causing
# inconsistency between ints.shape and orbit_multiplicities.shape. Fixed by replacing
# with expansion.cluster_subspace.orbit_multiplicities.
# preds = np.sum(
# cluster_expansion.cluster_subspace.orbit_multiplicities * ints, axis=1
# )
preds = np.sum(expansion.cluster_subspace.orbit_multiplicities * ints, axis=1)
npt.assert_allclose(preds, np.dot(pruned_feat_matrix, new_coefs), atol=1e-5)


Expand Down
5 changes: 5 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def assert_msonable(obj, skip_keys=None, test_if_subclass=True):
for key in d1.keys():
if key in skip_keys:
continue
# if d1[key] != d2[key]:
# print("key:", key)
# print("d1:", d1[key])
# print("d2:", d2[key])

assert d1[key] == d2[key]

try:
Expand Down

0 comments on commit 42ea6fc

Please sign in to comment.