Skip to content

Commit

Permalink
move more from shim to jinja
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed Oct 28, 2024
1 parent 21a55b7 commit a85ba7e
Show file tree
Hide file tree
Showing 20 changed files with 875 additions and 927 deletions.
42 changes: 24 additions & 18 deletions autotest/test_codegen.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest

from autotest.conftest import get_project_root_path
from flopy.mf6.utils.codegen import make_all, make_targets
from flopy.mf6.utils.codegen.context import Context
from flopy.mf6.utils.codegen.dfn import Dfn
from flopy.mf6.utils.codegen.make import make_all, make_targets

PROJ_ROOT = get_project_root_path()
MF6_PATH = PROJ_ROOT / "flopy" / "mf6"
Expand All @@ -18,14 +18,19 @@
@pytest.mark.parametrize("dfn_name", DFN_NAMES)
def test_dfn_load(dfn_name):
dfn_path = DFN_PATH / f"{dfn_name}.dfn"

common_path = DFN_PATH / "common.dfn"
with open(common_path, "r") as f:
common, _ = Dfn._load(f)

with open(dfn_path, "r") as f:
dfn = Dfn.load(f, name=Dfn.Name(*dfn_name.split("-")), common=common)
if dfn_name in ["sln-ems", "exg-gwfprt", "exg-gwfgwe", "exg-gwfgwt"]:
with open(common_path, "r") as common_file, open(
dfn_path, "r"
) as dfn_file:
name = Dfn.Name.parse(dfn_name)
common, _ = Dfn._load(common_file)
dfn = Dfn.load(dfn_file, name=name, common=common)
if name in [
("sln", "ems"),
("exg", "gwfprt"),
("exg", "gwfgwe"),
("exg", "gwfgwt"),
]:
assert not any(dfn)
else:
assert any(dfn)
Expand All @@ -34,17 +39,18 @@ def test_dfn_load(dfn_name):
@pytest.mark.parametrize("dfn_name", DFN_NAMES)
def test_make_targets(dfn_name, function_tmpdir):
common_path = DFN_PATH / "common.dfn"
with open(common_path, "r") as f:
common, _ = Dfn._load(f)

with open(DFN_PATH / f"{dfn_name}.dfn", "r") as f:
dfn = Dfn.load(f, name=Dfn.Name(*dfn_name.split("-")), common=common)

with open(common_path, "r") as common_file, open(
DFN_PATH / f"{dfn_name}.dfn", "r"
) as dfn_file:
name = Dfn.Name.parse(dfn_name)
common, _ = Dfn._load(common_file)
dfn = Dfn.load(dfn_file, name=name, common=common)

target_names = Context.Name.from_dfn(dfn)
make_targets(dfn, function_tmpdir, verbose=True)

for name in Context.Name.from_dfn(dfn):
source_path = function_tmpdir / name.target
assert source_path.is_file()
assert all(
(function_tmpdir / name.target).is_file() for name in target_names
)


def test_make_all(function_tmpdir):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@
from jinja2 import Environment, PackageLoader

from flopy.mf6.utils.codegen.context import Context
from flopy.mf6.utils.codegen.dfn import Dfn, Dfns
from flopy.mf6.utils.codegen.ref import Ref, Refs
from flopy.mf6.utils.codegen.dfn import Dfn, Dfns, Ref, Refs

__all__ = ["make_targets", "make_all"]

_TEMPLATE_LOADER = PackageLoader("flopy", "mf6/utils/codegen/templates/")
_TEMPLATE_ENV = Environment(loader=_TEMPLATE_LOADER)
_TEMPLATE_NAME = "context.py.jinja"
_TEMPLATE = _TEMPLATE_ENV.get_template(_TEMPLATE_NAME)


def make_targets(dfn: Dfn, outdir: Path, verbose: bool = False):
"""Generate Python source file(s) from the given input definition."""

for context in Context.from_dfn(dfn):
target = outdir / context.name.target
name = context.name
target = outdir / name.target
template = _TEMPLATE_ENV.get_template(name.template)
with open(target, "w") as f:
f.write(_TEMPLATE.render(**context.render()))
f.write(template.render(**context.render()))
if verbose:
print(f"Wrote {target}")

Expand Down
47 changes: 19 additions & 28 deletions flopy/mf6/utils/codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
Optional,
)

from flopy.mf6.utils.codegen.dfn import Dfn, Vars
from flopy.mf6.utils.codegen.ref import Ref
from flopy.mf6.utils.codegen.render import renderable
from flopy.mf6.utils.codegen.dfn import Dfn, Ref, Vars
from flopy.mf6.utils.codegen.renderable import renderable
from flopy.mf6.utils.codegen.shim import SHIM


Expand Down Expand Up @@ -99,6 +98,18 @@ def target(self) -> str:
"""The source file name to generate."""
return f"mf{self.title}.py"

@property
def template(self) -> str:
"""The template file to use."""
if self.base == "MFSimulationBase":
return "simulation.py.jinja"
elif self.base == "MFModel":
return "model.py.jinja"
elif self.base == "MFPackage":
if self.l == "exg":
return "exchange.py.jinja"
return "package.py.jinja"

@property
def description(self) -> str:
"""A description of the input context."""
Expand All @@ -109,29 +120,11 @@ def description(self) -> str:
elif self.base == "MFModel":
return f"Modflow{title} defines a {l.upper()} model."
elif self.base == "MFSimulationBase":
return """
MFSimulation is used to load, build, and/or save a MODFLOW 6 simulation.
A MFSimulation object must be created before creating any of the MODFLOW 6
model objects."""

def parent(self, ref: Optional[Ref] = None) -> Optional[str]:
"""
Return the name of the parent `__init__` method parameter,
or `None` if the context cannot have parents. Contexts can
have more than one possible parent, in which case the name
of the parameter is of the pattern `name1_or_..._or_nameN`.
"""
if ref:
return ref.parent
if self == ("sim", "nam"):
return None
elif (
self.l is None
or self.r is None
or self.l in ["sim", "exg", "sln"]
):
return "simulation"
return "model"
return (
"MFSimulation is used to load, build, and/or save a MODFLOW 6 simulation."
" A MFSimulation object must be created before creating any of the MODFLOW"
" 6 model objects."
)

@staticmethod
def from_dfn(dfn: Dfn) -> List["Context.Name"]:
Expand Down Expand Up @@ -172,7 +165,6 @@ def from_dfn(dfn: Dfn) -> List["Context.Name"]:
name: Name
vars: Vars
base: Optional[type] = None
parent: Optional[str] = None
description: Optional[str] = None
meta: Optional[Dict[str, Any]] = None

Expand All @@ -194,7 +186,6 @@ def from_dfn(cls, dfn: Dfn) -> Iterator["Context"]:
name=name,
vars=dfn.data,
base=name.base,
parent=name.parent(ref),
description=name.description,
meta=meta,
)
155 changes: 125 additions & 30 deletions flopy/mf6/utils/codegen/dfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from boltons.dictutils import OMD

from flopy.mf6.utils.codegen.utils import try_literal_eval, try_parse_bool

_SCALARS = {
"keyword",
"integer",
Expand All @@ -27,30 +29,7 @@

Vars = Dict[str, "Var"]
Dfns = Dict[str, "Dfn"]


def _try_parse_bool(value: Any) -> Any:
"""
Try to parse a boolean from a string as represented
in a DFN file, otherwise return the value unaltered.
"""

if isinstance(value, str):
value = value.lower()
if value in ["true", "false"]:
return value == "true"
return value


def _try_literal_eval(value: str) -> Any:
"""
Try to parse a string as a literal. If this fails,
return the value unaltered.
"""
try:
return literal_eval(value)
except (SyntaxError, ValueError):
return value
Refs = Dict[str, "Ref"]


@dataclass
Expand Down Expand Up @@ -87,13 +66,20 @@ class Dfn(UserDict):

class Name(NamedTuple):
"""
Uniquely identifies an input definition. A name
consists of a left term and optional right term.
Uniquely identifies an input definition.
Consists of a left term and a right term.
"""

l: str
r: str

@classmethod
def parse(cls, v: str) -> "Dfn.Name":
try:
return cls(*v.split("-"))
except:
raise ValueError(f"Bad DFN name format: {v}")

name: Optional[Name]
meta: Optional[Dict[str, Any]]

Expand Down Expand Up @@ -232,7 +218,7 @@ def _map(spec: Dict[str, Any]) -> Var:
# stay a string except default values, which we'll
# try to parse as arbitrary literals below, and at
# some point types, once we introduce type hinting
spec = {k: _try_parse_bool(v) for k, v in spec.items()}
spec = {k: try_parse_bool(v) for k, v in spec.items()}

# pull off attributes we're interested in
_name = spec["name"]
Expand Down Expand Up @@ -406,11 +392,12 @@ def _is_implicit_scalar_record():
block=block,
description=description,
default=(
_try_literal_eval(default)
if _type != "string"
else default
try_literal_eval(default) if _type != "string" else default
),
children=children,
# type is a string for now, when
# introducing type hints make it
# a proper type...
meta={"ref": ref, "type": type_},
)

Expand Down Expand Up @@ -440,3 +427,111 @@ def _is_implicit_scalar_record():
"refs": referenced,
},
)


@dataclass
class Ref:
"""
A foreign-key-like reference between a file input variable
and another input definition. This allows an input context
to refer to another input context, by including a filepath
variable whose name acts as a foreign key for a different
input context. The referring context's `__init__` method
is modified such that the variable named `val` replaces
the `key` variable.
Notes
-----
This class is used to represent subpackage references.
Parameters
----------
key : str
The name of the foreign key file input variable.
val : str
The name of the data variable in the referenced context.
abbr : str
An abbreviation of the referenced context's name.
param : str
The referenced parameter name.
parents : List[str]
The referenced context's supported parents.
description : Optional[str]
The reference's description.
"""

key: str
val: str
abbr: str
param: str
parent: str
description: Optional[str]

@classmethod
def from_dfn(cls, dfn: Dfn) -> Optional["Ref"]:
"""
Try to load a reference from the definition.
Returns `None` if the definition cannot be
referenced by other contexts.
"""

# TODO: all this won't be necessary once we
# structure DFN format; we can then support
# subpackage references directly instead of
# by making assumptions about `dfn.meta`

if not dfn.meta or "dfn" not in dfn.meta:
return None

_, meta = dfn.meta["dfn"]

lines = {
"subpkg": next(
iter(
m
for m in meta
if isinstance(m, str) and m.startswith("subpac")
),
None,
),
"parent": next(
iter(
m
for m in meta
if isinstance(m, str) and m.startswith("parent")
),
None,
),
}

def _subpkg():
line = lines["subpkg"]
_, key, abbr, param, val = line.split()
matches = [v for v in dfn.values() if v.name == val]
if not any(matches):
descr = None
else:
if len(matches) > 1:
warn(f"Multiple matches for referenced variable {val}")
match = matches[0]
descr = match.description

return {
"key": key,
"val": val,
"abbr": abbr,
"param": param,
"description": descr,
}

def _parent():
line = lines["parent"]
split = line.split()
return split[1]

return (
cls(**_subpkg(), parent=_parent())
if all(v for v in lines.values())
else None
)
Loading

0 comments on commit a85ba7e

Please sign in to comment.