Skip to content

Commit

Permalink
Fix interface between structure and view (#162)
Browse files Browse the repository at this point in the history
* structure plot should produce one "elements" entry per step
* add check that the number of steps are consistent
* check shapes of arrays is consistent

* prevent numpy 2 installation

---------

Co-authored-by: Martin Schlipf <[email protected]>
  • Loading branch information
sudarshanv01 and martin-schlipf authored Jun 18, 2024
1 parent 3c0be58 commit 3dbac03
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 29 deletions.
32 changes: 16 additions & 16 deletions core/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ repository = "https://github.com/vasp-dev/py4vasp"

[tool.poetry.dependencies]
python = ">=3.9"
numpy = ">=1.23"
numpy = "^1.23"
h5py = ">=3.7.0"

[tool.poetry.group.dev.dependencies]
Expand Down
7 changes: 3 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repository = "https://github.com/vasp-dev/py4vasp"

[tool.poetry.dependencies]
python = ">=3.9"
numpy = ">=1.23"
numpy = "^1.23"
h5py = ">=3.7.0"
pandas = ">=1.4.3"
nglview = ">=3.0.5"
Expand Down
33 changes: 33 additions & 0 deletions src/py4vasp/_third_party/view/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class View:
show_axes_at: Sequence[float] = None
"""Defines where the axis is shown, defaults to the origin"""

def __post_init__(self):
self._verify()

def _ipython_display_(self):
widget = self.to_ngl()
widget._ipython_display_()
Expand Down Expand Up @@ -143,6 +146,8 @@ def to_ngl(self):
def _verify(self):
self._raise_error_if_present_on_multiple_steps(self.grid_scalars)
self._raise_error_if_present_on_multiple_steps(self.ion_arrows)
self._raise_error_if_number_steps_inconsistent()
self._raise_error_if_any_shape_is_incorrect()

def _raise_error_if_present_on_multiple_steps(self, attributes):
if not attributes:
Expand All @@ -156,6 +161,34 @@ def _raise_error_if_present_on_multiple_steps(self, attributes):
attribute is supplied with its corresponding grid scalar or ion arrow component."""
)

def _raise_error_if_number_steps_inconsistent(self):
if len(self.elements) == len(self.lattice_vectors) == len(self.positions):
return
raise exception.IncorrectUsage(
"The shape of the arrays is inconsistent. Each of 'elements' (length = "
f"{len(self.elements)}), 'lattice_vectors' (length = "
f"{len(self.lattice_vectors)}), and 'positions' (length = "
f"{len(self.positions)}) should have a leading dimension of the number of"
"steps."
)

def _raise_error_if_any_shape_is_incorrect(self):
number_elements = len(self.elements[0])
_, number_positions, vector_size = np.shape(self.positions)
if number_elements != number_positions:
raise exception.IncorrectUsage(
f"Number of elements ({number_elements}) inconsistent with number of positions ({number_positions})."
)
if vector_size != 3:
raise exception.IncorrectUsage(
f"Positions must have three components and not {vector_size}."
)
cell_shape = np.shape(self.lattice_vectors)[1:]
if any(length != 3 for length in cell_shape):
raise exception.IncorrectUsage(
f"Lattice vectors must be a 3x3 unit cell but have the shape {cell_shape}."
)

def _create_atoms(self, step):
symbols = "".join(self.elements[step])
atoms = ase.Atoms(symbols)
Expand Down
6 changes: 4 additions & 2 deletions src/py4vasp/calculation/_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,12 @@ def to_view(self, supercell=None):
{examples}
"""
make_3d = lambda array: array if array.ndim == 3 else array[np.newaxis]
positions = make_3d(self.positions())
elements = np.tile(self._topology().elements(), (len(positions), 1))
return view.View(
elements=np.atleast_2d(self._topology().elements()),
elements=elements,
lattice_vectors=make_3d(self.lattice_vectors()),
positions=make_3d(self.positions()),
positions=positions,
supercell=self._parse_supercell(supercell),
)

Expand Down
50 changes: 45 additions & 5 deletions tests/third_party/view/test_view.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © VASP Software GmbH,
# Licensed under the Apache License 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import copy
import io
import itertools
from types import SimpleNamespace
Expand Down Expand Up @@ -103,15 +104,14 @@ def view3d(request, not_core):


@pytest.fixture
def view3d_fail(not_core):
def view_multiple_grid_scalars(not_core):
inputs = base_input_view(is_structure=False)
isosurface = Isosurface(isolevel=0.1, color="#2FB5AB", opacity=0.6)
charge_grid_scalar = GridQuantity(np.random.rand(2, 12, 10, 8), "charge")
potential_grid_scalar = GridQuantity(np.random.rand(2, 12, 10, 8), "potential")
potential_grid_scalar.isosurfaces = [isosurface]
grid_scalars = [charge_grid_scalar, potential_grid_scalar]
view = View(grid_scalars=grid_scalars, **inputs)
return view
return View(**inputs), grid_scalars


@pytest.fixture(params=[True, False])
Expand Down Expand Up @@ -188,9 +188,11 @@ def test_isosurface(view3d):
np.allclose(expected_data, output_data)


def test_fail_isosurface(view3d_fail):
def test_fail_isosurface(view_multiple_grid_scalars):
view, grid_scalars = view_multiple_grid_scalars
with pytest.raises(exception.NotImplemented):
widget = view3d_fail.to_ngl()
view.grid_scalars = grid_scalars
widget = view.to_ngl()


def test_ion_arrows(view_arrow):
Expand Down Expand Up @@ -292,3 +294,41 @@ def test_showaxes_different_origin(is_structure, not_core):
assert msg["args"][1][0][0] == "arrow"
expected_origin = np.array([0.2, 0.2, 0.2]) @ transformation.T
assert np.allclose(msg["args"][1][0][1], expected_origin)


def test_different_number_of_steps_raises_error(view):
too_many_elements = [element for element in view.elements] + [view.elements[0]]
with pytest.raises(exception.IncorrectUsage):
View(too_many_elements, view.lattice_vectors, view.positions)
with pytest.raises(exception.IncorrectUsage):
broken_view = copy.copy(view)
broken_view.elements = too_many_elements
broken_view.to_ngl()
#
too_many_cells = [cell for cell in view.lattice_vectors] + [view.lattice_vectors[0]]
with pytest.raises(exception.IncorrectUsage):
View(view.elements, too_many_cells, view.positions)
with pytest.raises(exception.IncorrectUsage):
broken_view = copy.copy(view)
broken_view.lattice_vectors = too_many_cells
broken_view.to_ngl()
#
too_many_positions = [position for position in view.positions] + [view.positions[0]]
with pytest.raises(exception.IncorrectUsage):
View(view.elements, view.lattice_vectors, too_many_positions)
with pytest.raises(exception.IncorrectUsage):
broken_view = copy.copy(view)
broken_view.positions = too_many_positions
broken_view.to_ngl()


def test_incorrect_shape_raises_error(view):
different_number_atoms = np.zeros((len(view.positions), 7, 3))
with pytest.raises(exception.IncorrectUsage):
View(view.elements, view.lattice_vectors, different_number_atoms)
not_a_three_component_vector = np.array(view.positions)[:, :, :2]
with pytest.raises(exception.IncorrectUsage):
View(view.elements, view.lattice_vectors, not_a_three_component_vector)
incorrect_unit_cell = np.zeros((len(view.lattice_vectors), 2, 4))
with pytest.raises(exception.IncorrectUsage):
View(view.elements, incorrect_unit_cell, view.positions)

0 comments on commit 3dbac03

Please sign in to comment.