Skip to content

Commit

Permalink
test for selection exception
Browse files Browse the repository at this point in the history
  • Loading branch information
sudarshanv01 committed Apr 23, 2024
1 parent 020031e commit 78a38fb
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/py4vasp/calculation/_OSZICAR.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,22 @@ def to_dict(self, selection=None):
"""
return_data = {}
if selection is None:
keys_to_include = getattr(self._raw_data, "label")
keys_to_include = self._from_bytes_to_utf(self._raw_data.label)
else:
if keys_to_include not in getattr(self._raw_data, "label"):
labels_as_str = self._from_bytes_to_utf(self._raw_data.label)
if selection not in labels_as_str:
message = """\
Please choose a selection including at least one of the following keywords:
N, E, dE, deps, ncg, rms, rms(c)"""
raise exception.RefinementError(message)
keys_to_include = selection
keys_to_include = [selection]
for key in keys_to_include:
return_data[key.decode("utf-8")] = self._read(key)
return_data[key] = self._read(key)
return return_data

def _from_bytes_to_utf(self, quantity: list):
return [_quantity.decode("utf-8") for _quantity in quantity]

@_base.data_access
def _read(self, key):
# data represents all of the electronic steps for all ionic steps
Expand All @@ -86,7 +90,8 @@ def _read(self, key):
data = [raw.VaspData(_data) for _data in data]
else:
data = [raw.VaspData(data)]
data_index = np.where(self._raw_data.label == key.strip())[0][0]
labels = [label.decode("utf-8") for label in self._raw_data.label]
data_index = labels.index(key)
return_data = [list(_data[:, data_index]) for _data in data]
is_none = [_data.is_none() for _data in data]
if len(return_data) == 1:
Expand Down
9 changes: 9 additions & 0 deletions tests/calculation/test_oszicar.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ def test_read(OSZICAR, Assert):
Assert.allclose(actual["rms(c)"], expected.rmsc)


@pytest.mark.parametrize(
"quantity_name", ["N", "E", "dE", "deps", "ncg", "rms", "rms(c)"]
)
def test_read_selection(quantity_name, OSZICAR, Assert):
actual = OSZICAR.read(quantity_name)
expected = getattr(OSZICAR.ref, quantity_name.replace("(", "").replace(")", ""))
Assert.allclose(actual[quantity_name], expected)


def test_plot(OSZICAR, Assert):
graph = OSZICAR.plot()
assert graph.xlabel == "Iteration number"
Expand Down

0 comments on commit 78a38fb

Please sign in to comment.