Skip to content

Commit

Permalink
Merge branch 'main' into fix/cli-workflows-pw-base
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber authored Sep 11, 2024
2 parents c595d7b + b79189d commit f28141b
Show file tree
Hide file tree
Showing 15 changed files with 369 additions and 25 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ aiida-quantumespresso = 'aiida_quantumespresso.cli:cmd_root'
'quantumespresso.seekpath_structure_analysis' = 'aiida_quantumespresso.calculations.functions.seekpath_structure_analysis:seekpath_structure_analysis'
'quantumespresso.xspectra' = 'aiida_quantumespresso.calculations.xspectra:XspectraCalculation'
'quantumespresso.open_grid' = 'aiida_quantumespresso.calculations.open_grid:OpenGridCalculation'
'quantumespresso.bands' = 'aiida_quantumespresso.calculations.bands:BandsCalculation'

[project.entry-points.'aiida.data']
'quantumespresso.force_constants' = 'aiida_quantumespresso.data.force_constants:ForceConstantsData'
Expand All @@ -105,6 +106,7 @@ aiida-quantumespresso = 'aiida_quantumespresso.cli:cmd_root'
'quantumespresso.pw2wannier90' = 'aiida_quantumespresso.parsers.pw2wannier90:Pw2wannier90Parser'
'quantumespresso.xspectra' = 'aiida_quantumespresso.parsers.xspectra:XspectraParser'
'quantumespresso.open_grid' = 'aiida_quantumespresso.parsers.open_grid:OpenGridParser'
'quantumespresso.bands' = 'aiida_quantumespresso.parsers.bands:BandsParser'

[project.entry-points.'aiida.tools.calculations']
'quantumespresso.pw' = 'aiida_quantumespresso.tools.calculations.pw:PwCalculationTools'
Expand All @@ -125,6 +127,7 @@ aiida-quantumespresso = 'aiida_quantumespresso.cli:cmd_root'
'quantumespresso.xps' = 'aiida_quantumespresso.workflows.xps:XpsWorkChain'
'quantumespresso.xspectra.core' = 'aiida_quantumespresso.workflows.xspectra.core:XspectraCoreWorkChain'
'quantumespresso.xspectra.crystal' = 'aiida_quantumespresso.workflows.xspectra.crystal:XspectraCrystalWorkChain'
'quantumespresso.bands.base' = 'aiida_quantumespresso.workflows.bands.base:BandsBaseWorkChain'

[tool.flit.module]
name = 'aiida_quantumespresso'
Expand Down
39 changes: 39 additions & 0 deletions src/aiida_quantumespresso/calculations/bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
"""`CalcJob` implementation for the bands.x code of Quantum ESPRESSO."""

from aiida import orm

from aiida_quantumespresso.calculations.namelists import NamelistsCalculation


class BandsCalculation(NamelistsCalculation):
"""`CalcJob` implementation for the bands.x code of Quantum ESPRESSO.
bands.x code of the Quantum ESPRESSO distribution, re-orders bands, and computes band-related properties.
It computes for instance the expectation value of the momentum operator:
<Psi(n,k) | i * m * [H, x] | Psi(m,k)>. For more information, refer to http://www.quantum-espresso.org/
"""

_MOMENTUM_OPERATOR_NAME = 'momentum_operator.dat'
_BANDS_NAME = 'bands.dat'

_default_namelists = ['BANDS']
_blocked_keywords = [
('BANDS', 'outdir', NamelistsCalculation._OUTPUT_SUBFOLDER), # pylint: disable=protected-access
('BANDS', 'prefix', NamelistsCalculation._PREFIX), # pylint: disable=protected-access
('BANDS', 'filband', _BANDS_NAME),
('BANDS', 'filp', _MOMENTUM_OPERATOR_NAME), # Momentum operator
]

_internal_retrieve_list = []
_default_parser = 'quantumespresso.bands'

@classmethod
def define(cls, spec):
"""Define the process specification."""
# yapf: disable
super().define(spec)
spec.input('parent_folder', valid_type=(orm.RemoteData, orm.FolderData), required=True)
spec.output('output_parameters', valid_type=orm.Dict)
# yapf: enable
27 changes: 27 additions & 0 deletions src/aiida_quantumespresso/parsers/bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
from aiida.orm import Dict

from aiida_quantumespresso.utils.mapping import get_logging_container

from .base import BaseParser


class BandsParser(BaseParser):
"""``Parser`` implementation for the ``BandsCalculation`` calculation job class."""

def parse(self, **kwargs):
"""Parse the retrieved files of a ``BandsCalculation`` into output nodes."""
logs = get_logging_container()

_, parsed_data, logs = self.parse_stdout_from_retrieved(logs)

base_exit_code = self.check_base_errors(logs)
if base_exit_code:
return self.exit(base_exit_code, logs)

self.out('output_parameters', Dict(parsed_data))

if 'ERROR_OUTPUT_STDOUT_INCOMPLETE'in logs.error:
return self.exit(self.exit_codes.ERROR_OUTPUT_STDOUT_INCOMPLETE, logs)

return self.exit(logs=logs)
Empty file.
61 changes: 61 additions & 0 deletions src/aiida_quantumespresso/workflows/bands/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""Workchain to run a Quantum ESPRESSO bands.x calculation with automated error handling and restarts."""
from aiida.common import AttributeDict
from aiida.engine import BaseRestartWorkChain, ProcessHandlerReport, process_handler, while_
from aiida.plugins import CalculationFactory

BandsCalculation = CalculationFactory('quantumespresso.bands')


class BandsBaseWorkChain(BaseRestartWorkChain):
"""Workchain to run a Quantum ESPRESSO bands.x calculation with automated error handling and restarts."""

_process_class = BandsCalculation

@classmethod
def define(cls, spec):
"""Define the process specification."""
# yapf: disable
super().define(spec)
spec.expose_inputs(BandsCalculation, namespace='bands')
spec.expose_outputs(BandsCalculation)
spec.outline(
cls.setup,
while_(cls.should_run_process)(
cls.run_process,
cls.inspect_process,
),
cls.results,
)
spec.exit_code(300, 'ERROR_UNRECOVERABLE_FAILURE',
message='The calculation failed with an unrecoverable error.')
# yapf: enable

def setup(self):
"""Call the `setup` of the `BaseRestartWorkChain` and then create the inputs dictionary in `self.ctx.inputs`.
This `self.ctx.inputs` dictionary will be used by the `BaseRestartWorkChain` to submit the calculations in the
internal loop.
"""
super().setup()
self.ctx.restart_calc = None
self.ctx.inputs = AttributeDict(self.exposed_inputs(BandsCalculation, 'bands'))

def report_error_handled(self, calculation, action):
"""Report an action taken for a calculation that has failed.
This should be called in a registered error handler if its condition is met and an action was taken.
:param calculation: the failed calculation node
:param action: a string message with the action taken
"""
arguments = [calculation.process_label, calculation.pk, calculation.exit_status, calculation.exit_message]
self.report('{}<{}> failed with exit status {}: {}'.format(*arguments))
self.report(f'Action taken: {action}')

@process_handler(priority=600)
def handle_unrecoverable_failure(self, node):
"""Handle calculations with an exit status below 400 which are unrecoverable, so abort the work chain."""
if node.is_failed and node.exit_status < 400:
self.report_error_handled(node, 'unrecoverable error, aborting...')
return ProcessHandlerReport(True, self.exit_codes.ERROR_UNRECOVERABLE_FAILURE)
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st
new_supercell = get_supercell_result['new_supercell']
output_params['supercell_factors'] = multiples

result['supercell'] = new_supercell
output_params['supercell_num_sites'] = len(new_supercell.sites)
output_params['supercell_cell_matrix'] = new_supercell.cell
output_params['supercell_cell_lengths'] = new_supercell.cell_lengths
Expand Down
105 changes: 81 additions & 24 deletions src/aiida_quantumespresso/workflows/xspectra/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Uses QuantumESPRESSO pw.x and xspectra.x.
"""
from aiida import orm
from aiida.common import AttributeDict, ValidationError
from aiida.common import AttributeDict
from aiida.engine import ToContext, WorkChain, if_
from aiida.orm import UpfData as aiida_core_upf
from aiida.plugins import CalculationFactory, DataFactory, WorkflowFactory
Expand Down Expand Up @@ -173,6 +173,19 @@ def define(cls, spec):
help=('Input namespace to provide core wavefunction inputs for each element. Must follow the format: '
'``core_wfc_data__{symbol} = {node}``')
)
spec.input_namespace(
'symmetry_data',
valid_type=(orm.Dict, orm.Int),
dynamic=True,
required=False,
help=(
'Input namespace to define equivalent sites and spacegroup number for the system. If defined, will '
'skip symmetry analysis and structure standardization. Use *only* if symmetry data are known '
'for certain. Requires ``spacegroup_number`` (Int) and ``equivalent_sites_data`` (Dict) to be '
'defined separately. All keys in `equivalent_sites_data` must be formatted as "site_<site_index>". '
'See docstring of `get_xspectra_structures` for more information about inputs.'
)
)
spec.inputs.validator = cls.validate_inputs
spec.outline(
cls.setup,
Expand Down Expand Up @@ -370,7 +383,7 @@ def get_builder_from_protocol( # pylint: disable=too-many-statements


@staticmethod
def validate_inputs(inputs, _):
def validate_inputs(inputs, _): # pylint: disable=too-many-return-statements
"""Validate the inputs before launching the WorkChain."""
structure = inputs['structure']
kinds_present = [kind.name for kind in structure.kinds]
Expand All @@ -382,54 +395,92 @@ def validate_inputs(inputs, _):
if element not in elements_present:
extra_elements.append(element)
if len(extra_elements) > 0:
raise ValidationError(
return (
f'Some elements in ``elements_list`` {extra_elements} do not exist in the'
f' structure provided {elements_present}.'
)

abs_atom_marker = inputs['abs_atom_marker'].value
if abs_atom_marker in kinds_present:
raise ValidationError(
return (
f'The marker given for the absorbing atom ("{abs_atom_marker}") matches an existing Kind in the '
f'input structure ({kinds_present}).'
)

if not inputs['core']['get_powder_spectrum'].value:
raise ValidationError(
return (
'The ``get_powder_spectrum`` input for the XspectraCoreWorkChain namespace must be ``True``.'
)

if 'upf2plotcore_code' not in inputs and 'core_wfc_data' not in inputs:
raise ValidationError(
return (
'Neither a ``Code`` node for upf2plotcore.sh or a set of ``core_wfc_data`` were provided.'
)

if 'core_wfc_data' in inputs:
core_wfc_data_list = sorted(inputs['core_wfc_data'].keys())
if core_wfc_data_list != absorbing_elements_list:
raise ValidationError(
return (
f'The ``core_wfc_data`` provided ({core_wfc_data_list}) does not match the list of'
f' absorbing elements ({absorbing_elements_list})'
)
else:
empty_core_wfc_data = []
for key, value in inputs['core_wfc_data'].items():
header_line = value.get_content()[:40]
try:
num_core_states = int(header_line.split(' ')[5])
except Exception as exc:
raise ValidationError(
'The core wavefunction data file is not of the correct format'
) from exc
if num_core_states == 0:
empty_core_wfc_data.append(key)
if len(empty_core_wfc_data) > 0:
raise ValidationError(
f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain '
'any wavefunction data.'
)
empty_core_wfc_data = []
for key, value in inputs['core_wfc_data'].items():
header_line = value.get_content()[:40]
try:
num_core_states = int(header_line.split(' ')[5])
except: # pylint: disable=bare-except
return (
'The core wavefunction data file is not of the correct format'
) # pylint: enable=bare-except
if num_core_states == 0:
empty_core_wfc_data.append(key)
if len(empty_core_wfc_data) > 0:
return (
f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain '
'any wavefunction data.'
)

if 'symmetry_data' in inputs:
spacegroup_number = inputs['symmetry_data']['spacegroup_number'].value
equivalent_sites_data = inputs['symmetry_data']['equivalent_sites_data'].get_dict()
if spacegroup_number <= 0 or spacegroup_number >= 231:
return (
f'Input spacegroup number ({spacegroup_number}) outside of valid range (1-230).'
)

input_elements = []
required_keys = sorted(['symbol', 'multiplicity', 'kind_name', 'site_index'])
invalid_entries = []
# We check three things here: (1) are there any site indices which are outside of the possible
# range of site indices (2) do we have all the required keys for each entry,
# and (3) is there a mismatch between `absorbing_elements_list` and the elements specified
# in the entries of `equivalent_sites_data`. These checks are intended only to avoid a crash.
# We assume otherwise that the user knows what they're doing and has set everything else
# to their preferences correctly.
for site_label, value in equivalent_sites_data.items():
if not set(required_keys).issubset(set(value.keys())) :
invalid_entries.append(site_label)
elif value['symbol'] not in input_elements:
input_elements.append(value['symbol'])
if value['site_index'] < 0 or value['site_index'] >= len(structure.sites):
return (
f'The site index for {site_label} ({value["site_index"]}) is outside the range of '
+ f'sites within the structure (0-{len(structure.sites) -1}).'
)

if len(invalid_entries) != 0:
return (
f'The required keys ({required_keys}) were not found in the following entries: {invalid_entries}'
)

sorted_input_elements = sorted(input_elements)
if sorted_input_elements != absorbing_elements_list:
return (f'Elements defined for sites in `equivalent_sites_data` ({sorted_input_elements}) '
f'do not match the list of absorbing elements ({absorbing_elements_list})')


# pylint: enable=too-many-return-statements
def setup(self):
"""Set required context variables."""
if 'core_wfc_data' in self.inputs.keys():
Expand Down Expand Up @@ -489,6 +540,12 @@ def get_xspectra_structures(self):
if 'spglib_settings' in self.inputs:
inputs['spglib_settings'] = self.inputs.spglib_settings

if 'symmetry_data' in self.inputs:
inputs['parse_symmetry'] = orm.Bool(False)
input_sym_data = self.inputs.symmetry_data
inputs['equivalent_sites_data'] = input_sym_data['equivalent_sites_data']
inputs['spacegroup_number'] = input_sym_data['spacegroup_number']

if 'relax' in self.inputs:
result = get_xspectra_structures(self.ctx.optimized_structure, **inputs)
else:
Expand Down
27 changes: 27 additions & 0 deletions tests/calculations/test_bands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""Tests for the `BandsCalculation` class."""
# pylint: disable=protected-access
from aiida.common import datastructures

from aiida_quantumespresso.calculations.bands import BandsCalculation


def test_bands_default(fixture_sandbox, generate_calc_job, generate_inputs_bands, file_regression):
"""Test a default `BandsCalculation`."""
entry_point_name = 'quantumespresso.bands'

inputs = generate_inputs_bands()
calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)

retrieve_list = [BandsCalculation._DEFAULT_OUTPUT_FILE] + BandsCalculation._internal_retrieve_list

# Check the attributes of the returned `CalcInfo`
assert isinstance(calc_info, datastructures.CalcInfo)
assert sorted(calc_info.retrieve_list) == sorted(retrieve_list)

with fixture_sandbox.open('aiida.in') as handle:
input_written = handle.read()

# Checks on the files written to the sandbox folder as raw input
assert sorted(fixture_sandbox.get_content_list()) == sorted(['aiida.in'])
file_regression.check(input_written, encoding='utf-8', extension='.in')
6 changes: 6 additions & 0 deletions tests/calculations/test_bands/test_bands_default.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
&BANDS
filband = 'bands.dat'
filp = 'momentum_operator.dat'
outdir = './out/'
prefix = 'aiida'
/
23 changes: 22 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# pylint: disable=redefined-outer-name,too-many-statements
# pylint: disable=redefined-outer-name,too-many-statements,too-many-lines
"""Initialise a text database and profile for pytest."""
from collections.abc import Mapping
import io
Expand Down Expand Up @@ -594,6 +594,27 @@ def _generate_inputs_q2r():
return _generate_inputs_q2r


@pytest.fixture
def generate_inputs_bands(fixture_sandbox, fixture_localhost, fixture_code, generate_remote_data):
"""Generate default inputs for a `BandsCalculation."""

def _generate_inputs_bands():
"""Generate default inputs for a `BandsCalculation."""
from aiida_quantumespresso.utils.resources import get_default_options

inputs = {
'code': fixture_code('quantumespresso.bands'),
'parent_folder': generate_remote_data(fixture_localhost, fixture_sandbox.abspath, 'quantumespresso.pw'),
'metadata': {
'options': get_default_options()
}
}

return inputs

return _generate_inputs_bands


@pytest.fixture
def generate_inputs_ph(
generate_calc_job_node, generate_structure, fixture_localhost, fixture_code, generate_kpoints_mesh
Expand Down
Loading

0 comments on commit f28141b

Please sign in to comment.