Skip to content
This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Commit

Permalink
isolate get_atom_features
Browse files Browse the repository at this point in the history
  • Loading branch information
Chi Chen committed Dec 16, 2019
1 parent 9da04cf commit 0a6db3f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions megnet/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def convert(self, structure: Structure, state_attributes: List = None) -> Dict:
if np.size(np.unique(index1)) < len(atoms):
raise RuntimeError("Isolated atoms found in the structure")
else:
return {'atom': np.array(atoms, dtype='int32').tolist(),
return {'atom': atoms,
'bond': bonds,
'state': state_attributes,
'index1': index1,
Expand All @@ -118,7 +118,8 @@ def get_atom_features(structure) -> List[int]:
Returns:
List of atomic numbers
"""
return [i.specie.Z for i in structure]
return np.array([i.specie.Z for i in structure],
dtype='int32').tolist()

def __call__(self, structure: Structure) -> Dict:
return self.convert(structure)
Expand Down Expand Up @@ -224,7 +225,7 @@ def convert(self, structure: Structure, state_attributes: List = None) -> Dict:
if np.size(np.unique(index1)) < len(atoms):
raise RuntimeError("Isolated atoms found in the structure")
else:
return {'atom': np.array(atoms, dtype='int32').tolist(),
return {'atom': atoms,
'bond': bonds,
'state': state_attributes,
'index1': index1,
Expand Down

0 comments on commit 0a6db3f

Please sign in to comment.