Skip to content

Commit

Permalink
refactor: applying same conditional logic to pymatgen edges
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkelvin committed Dec 20, 2024
1 parent 4a56a81 commit a40db35
Showing 1 changed file with 35 additions and 25 deletions.
60 changes: 35 additions & 25 deletions matsciml/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,32 +854,42 @@ def _all_sites_have_neighbors(neighbors):
raise ValueError(
f"No neighbors detected for structure with cutoff {cutoff}; {structure}"
)
keep = set()
# only keeps undirected edges that are unique through set
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
keep.add(
Edge(
src_idx,
site.index,
np.array(site.image),
is_undirected,
# if we assume undirected edges, apply a filter
if is_undirected:
keep = set()
# only keeps undirected edges that are unique through set
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
keep.add(
Edge(
src_idx,
site.index,
np.array(site.image),
is_undirected,
)
)
)
# now only keep the edges after the first loop
all_src, all_dst, all_images = [], [], []
num_atoms = len(structure.atomic_numbers)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# stop adding edges if either src/dst have accumulated enough neighbors
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1
# now only keep the edges after the first loop
all_src, all_dst, all_images = [], [], []
num_atoms = len(structure.atomic_numbers)
counter = {index: 0 for index in range(num_atoms)}
for edge in keep:
# stop adding edges if either src/dst have accumulated enough neighbors
if counter[edge.src] > max_neighbors or counter[edge.dst] > max_neighbors:
pass
else:
all_src.append(edge.src)
all_dst.append(edge.dst)
all_images.append(edge.image)
counter[edge.src] += 1
counter[edge.dst] += 1
# alternatively, just add the edges as is from pymatgen
else:
all_src, all_dst, all_images = [], [], []
for src_idx, dst_sites in enumerate(neighbors):
for site in dst_sites:
all_src.append(src_idx)
all_dst.append(site)
all_images.append(site.image)
if any([len(obj) == 0 for obj in [all_src, all_dst, all_images]]):
raise ValueError(
f"No images or edges to work off for cutoff {cutoff}."
Expand Down

0 comments on commit a40db35

Please sign in to comment.