Skip to content

Commit

Permalink
replace ints with int8s
Browse files Browse the repository at this point in the history
  • Loading branch information
jonpvandermause committed Sep 15, 2024
1 parent f8ba2c0 commit 26943d9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion flare/utils/parameter_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def summarize_group(self, group_type):
for ele in self.groups["specie"][idt]:
atom_n = atomic_numbers[ele]
if atom_n >= len(self.species_mask):
new_mask = np.ones(atom_n, dtype=np.int) * (nspecie - 1)
new_mask = np.ones(atom_n, dtype=np.int8) * (nspecie - 1)
new_mask[: len(self.species_mask)] = self.species_mask
self.species_mask = new_mask
self.species_mask[atom_n] = idt
Expand Down
2 changes: 1 addition & 1 deletion flare/utils/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def check_instantiation(hyps, cutoffs, kernels, param_dict):

# check mask has the right dimension and values
mask = param_dict[f"{kernel}_mask"]
param_dict[f"{kernel}_mask"] = nparray(mask, dtype=np.int)
param_dict[f"{kernel}_mask"] = nparray(mask, dtype=np.int8)

assert npmax(mask) < n
dim = Parameters.ndim[kernel]
Expand Down

0 comments on commit 26943d9

Please sign in to comment.