Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor PackageContainer composition over inheritance #2324

Merged
merged 7 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 22 additions & 23 deletions autotest/regression/test_mf6.py
Original file line number Diff line number Diff line change
Expand Up @@ -4497,22 +4497,21 @@ def test006_2models_mvr(function_tmpdir, example_data_path):
exg_pkg.exchangedata.set_data(exg_data)

# test getting packages
pkg_dict = parent_model.package_dict
assert len(pkg_dict) == 6
pkg_names = parent_model.package_names
assert len(pkg_names) == 6
pkg_list = parent_model.get_package()
assert len(pkg_list) == 6
# confirm that this is a copy of the original dictionary with references
# to the packages
del pkg_dict[pkg_names[0]]
assert len(pkg_dict) == 5
pkg_dict = parent_model.package_dict
assert len(pkg_dict) == 6

old_val = pkg_dict["dis"].nlay.get_data()
pkg_dict["dis"].nlay = 22
pkg_dict = parent_model.package_dict
assert pkg_dict["dis"].nlay.get_data() == 22
pkg_dict["dis"].nlay = old_val
del pkg_list[0]
assert len(pkg_list) == 5
pkg_list = parent_model.get_package()
assert len(pkg_list) == 6

dis_pkg = parent_model.get_package("dis")
old_val = dis_pkg.nlay.get_data()
dis_pkg.nlay = 22
pkg_list = parent_model.get_package()
assert dis_pkg.nlay.get_data() == 22
dis_pkg.nlay = old_val

# write simulation again
save_folder = function_tmpdir / "save"
Expand Down Expand Up @@ -4560,8 +4559,8 @@ def test006_2models_mvr(function_tmpdir, example_data_path):
model = sim.get_model(model_name)
for package in model_package_check:
assert (
package in model.package_type_dict
or package in sim.package_type_dict
model.get_package(package, type_only=True) is not None
or sim.get_package(package, type_only=True) is not None
) == (package in load_only or f"{package}6" in load_only)
assert (len(sim._exchange_files) > 0) == (
"gwf6-gwf6" in load_only or "gwf-gwf" in load_only
Expand All @@ -4577,10 +4576,10 @@ def test006_2models_mvr(function_tmpdir, example_data_path):
)
model_parent = sim.get_model("parent")
model_child = sim.get_model("child")
assert "oc" not in model_parent.package_type_dict
assert "oc" in model_child.package_type_dict
assert "npf" in model_parent.package_type_dict
assert "npf" not in model_child.package_type_dict
assert model_parent.get_package("oc") is None
assert model_child.get_package("oc") is not None
assert model_parent.get_package("npf") is not None
assert model_child.get_package("npf") is None

# test running a runnable load_only case
sim = MFSimulation.load(
Expand Down Expand Up @@ -4652,9 +4651,9 @@ def test001e_uzf_3lay(function_tmpdir, example_data_path):
sim.set_sim_path(function_tmpdir)
model = sim.get_model()
for package in model_package_check:
assert (package in model.package_type_dict) == (
package in load_only or f"{package}6" in load_only
)
assert (
model.get_package(package, type_only=True) is not None
) == (package in load_only or f"{package}6" in load_only)
# test running a runnable load_only case
sim = MFSimulation.load(
model_name, "mf6", "mf6", pth, load_only=load_only_lists[0]
Expand Down
12 changes: 2 additions & 10 deletions autotest/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ def model_is_copy(m1, m2):
if k in [
"_packagelist",
"_package_paths",
"package_key_dict",
"package_type_dict",
"package_name_dict",
"package_filename_dict",
"_ftype_num_dict",
]:
continue
Expand Down Expand Up @@ -97,17 +93,13 @@ def package_is_copy(pk1, pk2):
if k in [
"_child_package_groups",
"_data_list",
"_packagelist",
"_simulation_data",
"simulation_data",
"blocks",
"dimensions",
"package_key_dict",
"package_name_dict",
"package_filename_dict",
"package_type_dict",
"post_block_comments",
"simulation_data",
"structure",
"_package_container",
]:
continue
elif isinstance(v, MFPackage):
Expand Down
18 changes: 9 additions & 9 deletions autotest/test_model_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,19 +467,19 @@ def test_empty_packages(function_tmpdir):
m0 = new_sim.get_model(f"{base_name}_0")
m1 = new_sim.get_model(f"{base_name}_1")

if "chd_0" in m0.package_dict:
raise AssertionError(f"Empty CHD file written to {base_name}_0 model")

if "wel_0" in m1.package_dict:
raise AssertionError(f"Empty WEL file written to {base_name}_1 model")
assert not m0.get_package(
name="chd_0"
), f"Empty CHD file written to {base_name}_0 model"
assert not m1.get_package(
name="wel_0"
), f"Empty WEL file written to {base_name}_1 model"

mvr_status0 = m0.sfr.mover.array
mvr_status1 = m0.sfr.mover.array

if not mvr_status0 or not mvr_status1:
raise AssertionError(
"Mover status being overwritten in options splitting"
)
assert (
mvr_status0 and mvr_status1
), "Mover status being overwritten in options splitting"


@requires_exe("mf6")
Expand Down
43 changes: 13 additions & 30 deletions flopy/mf6/mfbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pathlib import Path
from shutil import copyfile
from typing import Union
from warnings import warn


# internal handled exceptions
Expand Down Expand Up @@ -454,24 +453,13 @@ class PackageContainer:
modflow_models = []
models_by_type = {}

def __init__(self, simulation_data, name):
self.type = "PackageContainer"
self.simulation_data = simulation_data
self.name = name
self._packagelist = []
def __init__(self, simulation_data):
self._simulation_data = simulation_data
deltamarnix marked this conversation as resolved.
Show resolved Hide resolved
self.packagelist = []
self.package_type_dict = {}
self.package_name_dict = {}
self.package_filename_dict = {}

@property
def package_key_dict(self):
warnings.warn(
"package_key_dict has been deprecated, use "
"package_type_dict instead",
category=DeprecationWarning,
)
return self.package_type_dict

@staticmethod
def package_list():
"""Static method that returns the list of available packages.
Expand Down Expand Up @@ -554,9 +542,9 @@ def package_names(self):
"""Returns a list of package names."""
return list(self.package_name_dict.keys())

def _add_package(self, package, path):
def add_package(self, package):
# put in packages list and update lookup dictionaries
self._packagelist.append(package)
self.packagelist.append(package)
if package.package_name is not None:
self.package_name_dict[package.package_name.lower()] = package
if package.filename is not None:
Expand All @@ -565,9 +553,9 @@ def _add_package(self, package, path):
self.package_type_dict[package.package_type.lower()] = []
self.package_type_dict[package.package_type.lower()].append(package)

def _remove_package(self, package):
if package in self._packagelist:
self._packagelist.remove(package)
def remove_package(self, package):
if package in self.packagelist:
self.packagelist.remove(package)
if (
package.package_name is not None
and package.package_name.lower() in self.package_name_dict
Expand All @@ -587,7 +575,7 @@ def _remove_package(self, package):

# collect keys of items to be removed from main dictionary
items_to_remove = []
for key in self.simulation_data.mfdata:
for key in self._simulation_data.mfdata:
is_subkey = True
for pitem, ditem in zip(package.path, key):
if pitem != ditem:
Expand All @@ -598,7 +586,7 @@ def _remove_package(self, package):

# remove items from main dictionary
for key in items_to_remove:
del self.simulation_data.mfdata[key]
del self._simulation_data.mfdata[key]

def _rename_package(self, package, new_name):
# fix package_name_dict key
Expand All @@ -609,7 +597,7 @@ def _rename_package(self, package, new_name):
del self.package_name_dict[package.package_name.lower()]
self.package_name_dict[new_name.lower()] = package
# get keys to fix in main dictionary
main_dict = self.simulation_data.mfdata
main_dict = self._simulation_data.mfdata
items_to_fix = []
for key in main_dict:
is_subkey = True
Expand Down Expand Up @@ -648,7 +636,7 @@ def get_package(self, name=None, type_only=False, name_only=False):

"""
if name is None:
return self._packagelist[:]
return self.packagelist[:]

# search for full package name
if name.lower() in self.package_name_dict and not type_only:
Expand All @@ -669,7 +657,7 @@ def get_package(self, name=None, type_only=False, name_only=False):

# search for partial and case-insensitive package name
if not type_only:
for pp in self._packagelist:
for pp in self.packagelist:
if pp.package_name is not None:
# get first package of the type requested
package_name = pp.package_name.lower()
Expand All @@ -680,11 +668,6 @@ def get_package(self, name=None, type_only=False, name_only=False):

return None

def register_package(self, package):
"""Base method for registering a package. Should be overridden."""
path = (package.package_name,)
return (path, None)

@staticmethod
def _load_only_dict(load_only):
if load_only is None:
Expand Down
Loading
Loading