Skip to content

Commit

Permalink
io: fix saving/loading of HDiv/HCurl functions on a high-order mesh
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Oct 24, 2024
1 parent fefed5e commit 5546b4a
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 52 deletions.
200 changes: 151 additions & 49 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from firedrake.cython import hdf5interface as h5i
from firedrake.cython import dmcommon
from firedrake.petsc import PETSc, OptionsManager
from firedrake.mesh import MeshTopology, ExtrudedMeshTopology, DEFAULT_MESH_NAME, make_mesh_from_coordinates, DistributedMeshOverlapType
from firedrake.mesh import MeshGeometry, MeshTopology, ExtrudedMeshTopology, DEFAULT_MESH_NAME, make_mesh_from_coordinates, DistributedMeshOverlapType
from firedrake.functionspace import FunctionSpace
from firedrake import functionspaceimpl as impl
from firedrake.functionspacedata import get_global_numbering, create_element
Expand All @@ -20,6 +20,7 @@
import numpy as np
import os
import h5py
from typing import Optional, Union


__all__ = ["DumbCheckpoint", "HDF5File", "FILE_READ", "FILE_CREATE", "FILE_UPDATE", "CheckpointFile"]
Expand Down Expand Up @@ -896,25 +897,47 @@ def _save_function_space_topology(self, tV):
topology_dm.setName(base_tmesh_name)

@PETSc.Log.EventDecorator("SaveFunction")
def save_function(self, f, idx=None, name=None, timestepping_info={}):
r"""Save a :class:`~.Function`.
def save_function(
self,
f: Function,
idx: Optional[int] = None,
name: Optional[str] = None,
timestepping_info: Optional[dict] = {},
affine_coordinates: Optional[Union[MeshGeometry, Function]] = None,
affine_quadrature_degree: Optional[int] = None,
) -> None:
"""Save a :class:`~.Function`.
:arg f: the :class:`~.Function` to save.
:kwarg idx: optional timestepping index. A function can
Parameters
----------
f
`Function` to save.
idx
Optional timestepping index. A function can
either be saved in timestepping mode or in normal
mode (non-timestepping); for each function of interest,
this method must always be called with the idx parameter
set or never be called with the idx parameter set.
:kwarg name: optional alternative name to save the function under.
:kwarg timestepping_info: optional (requires idx) additional information
name
Optional alternative name to save the function under.
timestepping_info
Optional (requires idx) additional information
such as time, timestepping that can be stored along a function for
each index.
affine_coordinates
Representation of a fictitious affine mesh onto which
the function is mapped before saving; only significant for
HDiv/HCurl functions defined on high-order mesh.
affine_quadrature_degree
Quadrature degree to be used when mapping onto the affine mesh;
only significant for HDiv/HCurl functions defined on high-order mesh.
"""
V = f.function_space()
mesh = V.mesh()
if name:
g = Function(V, val=f.dat, name=name)
return self.save_function(g, idx=idx, timestepping_info=timestepping_info)
return self.save_function(g, idx=idx, timestepping_info=timestepping_info, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
# -- Save function space --
self._save_function_space(V)
# -- Save function --
Expand All @@ -926,7 +949,7 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
path = os.path.join(base_path, str(i))
self.require_group(path)
self.set_attr(path, PREFIX + "_function", fsub.name())
self.save_function(fsub, idx=idx, timestepping_info=timestepping_info)
self.save_function(fsub, idx=idx, timestepping_info=timestepping_info, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
self._update_mixed_function_name_mixed_function_space_name_map(mesh.name, {f.name(): V_name})
else:
tf = f.topological
Expand All @@ -940,10 +963,32 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
path = self._path_to_function_embedded(tmesh.name, mesh.name, V_name, f.name())
self.require_group(path)
method = get_embedding_method_for_checkpointing(element)
_V = FunctionSpace(mesh, _element)
if mesh.coordinates.function_space().ufl_element().embedded_subdegree > 1:
# Handle non-affine mesh; this is only relevant when embedding into a DG space.
if affine_coordinates is None:
raise ValueError("Must provide affine_coordinates to save functions on high-order mesh")
if affine_quadrature_degree is None:
raise ValueError("Must provide affine_quadrature_degree to save functions on high-order mesh")
if isinstance(affine_coordinates, MeshGeometry):
affine_coordinates = affine_coordinates.coordinates
else:
if not isinstance(affine_coordinates, Function):
raise ValueError("affine_coordinates must be {MeshGeometry, Function}")
if affine_coordinates.function_space().mesh().topology is not tmesh:
raise ValueError(f"affine_coordinates.function_space().mesh().topology ({affine_coordinates.function_space().mesh().topology}) is not f.mesh().topology ({tmesh})")
if affine_coordinates.function_space().mesh() is not mesh:
affine_coordinate_V = FunctionSpace(mesh, affine_coordinates.function_space().ufl_element())
affine_coordinates = Function(affine_coordinate_V, val=affine_coordinates.topological)
if affine_coordinates.topological.name() == mesh.coordinates.topological.name():
raise ValueError(f"affine_coordinate.name() ({affine_coordinates.name()}) == mesh.coordinates.topological.name() ({mesh.coordinates.topological.name()})")
self._save_ufl_element(path, PREFIX_EMBEDDED + "_affine_coordinate_element", affine_coordinates.topological.function_space().ufl_element())
self.set_attr(path, PREFIX_EMBEDDED + "_affine_coordinates", affine_coordinates.topological.name())
self.set_attr(path, PREFIX_EMBEDDED + "_affine_quadrature_degree", affine_quadrature_degree)
self._save_function_topology(affine_coordinates.topological)
_name = "_".join([PREFIX_EMBEDDED, f.name()])
_V = FunctionSpace(mesh, _element)
_f = Function(_V, name=_name)
self._project_function_for_checkpointing(_f, f, method)
self._project_function_for_checkpointing(_f, f, method, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
self.save_function(_f, idx=idx, timestepping_info=timestepping_info)
self.set_attr(path, PREFIX_EMBEDDED + "_function", _name)
else:
Expand Down Expand Up @@ -1045,35 +1090,41 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
path = self._path_to_topology_extruded(tmesh_name)
if path in self.h5pyfile:
# -- Load mesh topology --
base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters)
base_tmesh.init()
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
if variable_layers:
cell = base_tmesh.ufl_cell()
element = finat.ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2)
_ = self._load_function_space_topology(base_tmesh, element)
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
nroots, _, _ = lsf.getGraph()
layers_a = np.empty(nroots, dtype=utils.IntType)
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
self.viewer.pushGroup(path)
layers_a_iset.load(self.viewer)
self.viewer.popGroup()
layers_a = layers_a_iset.getIndices()
layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType)
unit = MPI._typedict[np.dtype(utils.IntType).char]
lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE)
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
if topology is None:
base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters)
base_tmesh.init()
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
if variable_layers:
cell = base_tmesh.ufl_cell()
element = finat.ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2)
_ = self._load_function_space_topology(base_tmesh, element)
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
nroots, _, _ = lsf.getGraph()
layers_a = np.empty(nroots, dtype=utils.IntType)
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
self.viewer.pushGroup(path)
layers_a_iset.load(self.viewer)
self.viewer.popGroup()
layers_a = layers_a_iset.getIndices()
layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType)
unit = MPI._typedict[np.dtype(utils.IntType).char]
lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE)
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
else:
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
else:
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
if topology.name != tmesh_name:
raise RuntimeError(f"Got wrong mesh topology (f{topology.name}): expecting f{tmesh_name}")
tmesh = topology
base_tmesh = topology._base_mesh
# -- Load mesh --
path = self._path_to_mesh(tmesh_name, name)
coord_element = self._load_ufl_element(path, PREFIX + "_coordinate_element")
Expand Down Expand Up @@ -1301,14 +1352,29 @@ def _load_function_space_topology(self, tmesh, element):
return impl.FunctionSpace(tmesh, element)

@PETSc.Log.EventDecorator("LoadFunction")
def load_function(self, mesh, name, idx=None):
r"""Load a :class:`~.Function` defined on `mesh`.
def load_function(
self,
mesh: MeshGeometry,
name: str,
idx: Optional[int] = None
) -> Function:
"""Load a :class:`~.Function` defined on ``mesh``.
:arg mesh: the mesh on which the function is defined.
:arg name: the name of the :class:`~.Function` to load.
:kwarg idx: optional timestepping index. A function can
Parameters
----------
mesh
mesh on which the function is defined.
name
name of the `Function` to load.
idx
Optional timestepping index. A function can
be loaded with idx only when it was saved with idx.
:returns: the loaded :class:`~.Function`.
Returns
-------
Function
Loaded `Function`.
"""
tmesh = mesh.topology
if name in self._get_mixed_function_name_mixed_function_space_name_map(mesh.name):
Expand Down Expand Up @@ -1341,7 +1407,19 @@ def load_function(self, mesh, name, idx=None):
method = get_embedding_method_for_checkpointing(element)
assert _element == _f.function_space().ufl_element()
f = Function(V, name=name)
self._project_function_for_checkpointing(f, _f, method)
if mesh.coordinates.function_space().ufl_element().embedded_subdegree > 1 and \
self.has_attr(path, PREFIX_EMBEDDED + "_affine_coordinates"):
# Handle non-affine mesh; this is only relevant when embedding into a DG space.
affine_coord_element = self._load_ufl_element(path, PREFIX_EMBEDDED + "_affine_coordinate_element")
affine_coord_name = self.get_attr(path, PREFIX_EMBEDDED + "_affine_coordinates")
affine_quadrature_degree = self.get_attr(path, PREFIX_EMBEDDED + "_affine_quadrature_degree")
affine_coordinates = self._load_function_topology(tmesh, affine_coord_element, affine_coord_name)
affine_coordinate_V = FunctionSpace(mesh, affine_coordinates.function_space().ufl_element())
affine_coordinates = Function(affine_coordinate_V, val=affine_coordinates.topological)
else:
affine_coordinates = None
affine_quadrature_degree = None
self._project_function_for_checkpointing(f, _f, method, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
return f
else:
tf_name = self.get_attr(path, PREFIX + "_vec")
Expand Down Expand Up @@ -1638,11 +1716,35 @@ def _is_mixed_function_space(self, mesh_name, V_name):
return True
return False

def _project_function_for_checkpointing(self, f, _f, method):
def _project_function_for_checkpointing(self, target, source, method, affine_coordinates=None, affine_quadrature_degree=None):
if affine_coordinates:
if affine_quadrature_degree is None:
raise ValueError("Need affine_quadrature_degree to save/load HDiv/HCurl functions on high-order mesh")
# Need to map to/from the representation on a fictitious
# affine mesh represented by affine_coordinates.
K = firedrake.grad(affine_coordinates) # K = (\partial X /\partial x) = F^-1.
from_elem = source.function_space().ufl_element()
to_elem = target.function_space().ufl_element()
if to_elem.mapping() == "identity":
if from_elem.mapping() == "covariant Piola":
source = firedrake.transpose(firedrake.inv(K)) * source
elif from_elem.mapping() == "contravariant Piola":
source = 1. / firedrake.det(K) * K * source
else:
raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})")
elif from_elem.mapping() == "identity":
if to_elem.mapping() == "covariant Piola":
source = firedrake.transpose(K) * source
elif to_elem.mapping() == "contravariant Piola":
source = firedrake.det(K) * firedrake.inv(K) * source
else:
raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})")
else:
raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})")
if method == "project":
getattr(f, method)(_f, solver_parameters={"ksp_rtol": 1.e-16})
getattr(target, method)(source, solver_parameters={"ksp_rtol": 1.e-16}, quadrature_degree=affine_quadrature_degree)
elif method == "interpolate":
getattr(f, method)(_f)
getattr(target, method)(source)
else:
raise ValueError(f"Unknown method for projecting: {method}")

Expand Down
Loading

0 comments on commit 5546b4a

Please sign in to comment.