feat(pt): support use_aparam_as_mask for pt backend (#4246)

support `use_aparam_as_mask` for pt backend

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Introduced the `use_aparam_as_mask` parameter in the `GeneralFitting`, `InvarFitting`, and `EnerFitting` classes, allowing users to conditionally exclude atomic parameters from the fitting process.
  - Added a `seed` parameter to `InvarFitting` for control over randomness.
  - New test method `test_use_aparam_as_mask` in `TestInvarFitting` to validate behavior based on the new parameter.

- **Bug Fixes**
  - Improved error handling for `use_aparam_as_mask` in various classes.

- **Tests**
  - Enhanced parameterization in multiple test classes to accommodate the new atomic-parameter features.
  - Updated test methods in `TestInvarFitting` to include `use_aparam_as_mask` for comprehensive testing.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
ChiahsinChu authored Oct 26, 2024
1 parent fa61d69 commit 5394854
Showing 10 changed files with 162 additions and 28 deletions.
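
As context for the diffs below, here is a minimal usage sketch of the new flag on the pt backend. This is not part of the commit: the module path and keyword names follow the pt `InvarFitting` touched by this PR, while the concrete values, tensor shapes, and the CPU/float64 assumptions are placeholders.

```python
# Hedged usage sketch -- not from the commit itself. Keyword values, shapes,
# and dtypes are placeholders; only `numb_aparam` and `use_aparam_as_mask`
# exercise the behavior added here. Assumes a CPU run with float64 fitting.
import torch

from deepmd.pt.model.task.invar_fitting import InvarFitting

fitting = InvarFitting(
    var_name="energy",
    ntypes=2,
    dim_descrpt=64,
    dim_out=1,
    numb_aparam=1,            # one atomic parameter per atom ...
    use_aparam_as_mask=True,  # ... excluded from the fitting-net input
)

descriptor = torch.rand(1, 4, 64, dtype=torch.float64)  # (nframes, nloc, dim_descrpt)
atype = torch.zeros(1, 4, dtype=torch.long)
aparam = torch.ones(1, 4, 1, dtype=torch.float64)       # accepted, but not embedded

out = fitting(descriptor, atype, aparam=aparam)["energy"]
print(out.shape)  # expected: torch.Size([1, 4, 1])
```
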
8 changes: 6 additions & 2 deletions deepmd/dpmodel/fitting/general_fitting.py
@@ -173,7 +173,11 @@ def __init__(
else:
self.aparam_avg, self.aparam_inv_std = None, None
# init networks
in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam
in_dim = (
self.dim_descrpt
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
)
self.nets = NetworkCollection(
1 if not self.mixed_types else 0,
self.ntypes,
@@ -401,7 +405,7 @@ def _call_common(
axis=-1,
)
# check aparam dim, concate to input descriptor
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
assert aparam is not None, "aparam should not be None"
if aparam.shape[-1] != self.numb_aparam:
raise ValueError(
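
The core of the change above is that the aparam width is dropped from the fitting-net input when the flag is set. A self-contained sketch of that arithmetic (the helper name is mine, not part of DeePMD-kit):

```python
# Minimal sketch: the fitting-net input width with and without
# `use_aparam_as_mask`, mirroring the `in_dim` expression in the hunk above.
def fitting_input_dim(
    dim_descrpt: int,
    numb_fparam: int,
    numb_aparam: int,
    use_aparam_as_mask: bool,
) -> int:
    return dim_descrpt + numb_fparam + (0 if use_aparam_as_mask else numb_aparam)


assert fitting_input_dim(64, 2, 1, use_aparam_as_mask=False) == 67
assert fitting_input_dim(64, 2, 1, use_aparam_as_mask=True) == 66
```
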
4 changes: 0 additions & 4 deletions deepmd/dpmodel/fitting/invar_fitting.py
@@ -139,10 +139,6 @@ def __init__(
raise NotImplementedError("tot_ener_zero is not implemented")
if spin is not None:
raise NotImplementedError("spin is not implemented")
if use_aparam_as_mask:
raise NotImplementedError("use_aparam_as_mask is not implemented")
if use_aparam_as_mask:
raise NotImplementedError("use_aparam_as_mask is not implemented")
if layer_name is not None:
raise NotImplementedError("layer_name is not implemented")

15 changes: 11 additions & 4 deletions deepmd/pt/model/task/fitting.py
@@ -126,6 +126,8 @@ class GeneralFitting(Fitting):
length as `ntypes` signaling if or not removing the vaccum contribution for the atom types in the list.
type_map: list[str], Optional
A list of strings. Give the name to each type of atoms.
use_aparam_as_mask: bool
If True, the aparam will not be used in fitting net for embedding.
"""

def __init__(
@@ -147,6 +149,7 @@ def __init__(
trainable: Union[bool, list[bool]] = True,
remove_vaccum_contribution: Optional[list[bool]] = None,
type_map: Optional[list[str]] = None,
use_aparam_as_mask: bool = False,
**kwargs,
):
super().__init__()
@@ -164,6 +167,7 @@ def __init__(
self.rcond = rcond
self.seed = seed
self.type_map = type_map
self.use_aparam_as_mask = use_aparam_as_mask
# order matters, should be place after the assignment of ntypes
self.reinit_exclude(exclude_types)
self.trainable = trainable
@@ -208,7 +212,11 @@ def __init__(
else:
self.aparam_avg, self.aparam_inv_std = None, None

in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam
in_dim = (
self.dim_descrpt
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
)

self.filter_layers = NetworkCollection(
1 if not self.mixed_types else 0,
@@ -293,13 +301,12 @@ def serialize(self) -> dict:
# "trainable": self.trainable ,
# "atom_ener": self.atom_ener ,
# "layer_name": self.layer_name ,
# "use_aparam_as_mask": self.use_aparam_as_mask ,
# "spin": self.spin ,
## NOTICE: not supported by far
"tot_ener_zero": False,
"trainable": [self.trainable] * (len(self.neuron) + 1),
"layer_name": None,
"use_aparam_as_mask": False,
"use_aparam_as_mask": self.use_aparam_as_mask,
"spin": None,
}

@@ -441,7 +448,7 @@ def _forward_common(
dim=-1,
)
# check aparam dim, concate to input descriptor
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
assert aparam is not None, "aparam should not be None"
assert self.aparam_avg is not None
assert self.aparam_inv_std is not None
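
The same guard controls whether aparam is normalized and concatenated in the forward pass. A stripped-down stand-in for that branch (NumPy instead of torch; the names follow the diff, the function itself is illustrative only):

```python
import numpy as np


def maybe_concat_aparam(
    xx, aparam, numb_aparam, aparam_avg, aparam_inv_std, use_aparam_as_mask
):
    # Sketch of the branch guarded in `_forward_common` / `_call_common`.
    if numb_aparam > 0 and not use_aparam_as_mask:
        assert aparam is not None, "aparam should not be None"
        aparam = (aparam - aparam_avg) * aparam_inv_std  # normalize
        xx = np.concatenate([xx, aparam], axis=-1)       # append to descriptor
    return xx


xx = np.zeros((1, 4, 66))    # (nframes, nloc, dim_descrpt + numb_fparam)
aparam = np.ones((1, 4, 1))  # one atomic parameter per atom
assert maybe_concat_aparam(xx, aparam, 1, 0.0, 1.0, True).shape[-1] == 66   # masked: ignored
assert maybe_concat_aparam(xx, aparam, 1, 0.0, 1.0, False).shape[-1] == 67  # unmasked: appended
```
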
5 changes: 4 additions & 1 deletion deepmd/pt/model/task/invar_fitting.py
@@ -77,7 +77,8 @@ class InvarFitting(GeneralFitting):
The `set_davg_zero` key in the descrptor should be set.
type_map: list[str], Optional
A list of strings. Give the name to each type of atoms.
use_aparam_as_mask: bool
If True, the aparam will not be used in fitting net for embedding.
"""

def __init__(
@@ -99,6 +100,7 @@ def __init__(
exclude_types: list[int] = [],
atom_ener: Optional[list[Optional[torch.Tensor]]] = None,
type_map: Optional[list[str]] = None,
use_aparam_as_mask: bool = False,
**kwargs,
):
self.dim_out = dim_out
@@ -122,6 +124,7 @@ def __init__(
if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0
else [x is not None for x in atom_ener],
type_map=type_map,
use_aparam_as_mask=use_aparam_as_mask,
**kwargs,
)

32 changes: 21 additions & 11 deletions deepmd/tf/fit/ener.py
@@ -384,7 +384,7 @@ def _build_lower(
ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam])
ext_fparam = tf.cast(ext_fparam, self.fitting_precision)
layer = tf.concat([layer, ext_fparam], axis=1)
if aparam is not None:
if aparam is not None and not self.use_aparam_as_mask:
ext_aparam = tf.slice(
aparam,
[0, start_index * self.numb_aparam],
@@ -561,7 +561,7 @@ def build(
trainable=False,
initializer=tf.constant_initializer(self.fparam_inv_std),
)
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
t_aparam_avg = tf.get_variable(
"t_aparam_avg",
self.numb_aparam,
@@ -576,6 +576,13 @@ def build(
trainable=False,
initializer=tf.constant_initializer(self.aparam_inv_std),
)
else:
t_aparam_avg = tf.zeros(
self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION
)
t_aparam_istd = tf.ones(
self.numb_aparam, dtype=GLOBAL_TF_FLOAT_PRECISION
)

inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt])
if len(self.atom_ener):
@@ -602,12 +609,11 @@ def build(
fparam = (fparam - t_fparam_avg) * t_fparam_istd

aparam = None
if not self.use_aparam_as_mask:
if self.numb_aparam > 0:
aparam = input_dict["aparam"]
aparam = tf.reshape(aparam, [-1, self.numb_aparam])
aparam = (aparam - t_aparam_avg) * t_aparam_istd
aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
aparam = input_dict["aparam"]
aparam = tf.reshape(aparam, [-1, self.numb_aparam])
aparam = (aparam - t_aparam_avg) * t_aparam_istd
aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])

atype_nall = tf.reshape(atype, [-1, natoms[1]])
self.atype_nloc = tf.slice(
@@ -783,7 +789,7 @@ def init_variables(
self.fparam_inv_std = get_tensor_by_name_from_graph(
graph, f"fitting_attr{suffix}/t_fparam_istd"
)
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
self.aparam_avg = get_tensor_by_name_from_graph(
graph, f"fitting_attr{suffix}/t_aparam_avg"
)
@@ -883,7 +889,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
if fitting.numb_fparam > 0:
fitting.fparam_avg = data["@variables"]["fparam_avg"]
fitting.fparam_inv_std = data["@variables"]["fparam_inv_std"]
if fitting.numb_aparam > 0:
if fitting.numb_aparam > 0 and not fitting.use_aparam_as_mask:
fitting.aparam_avg = data["@variables"]["aparam_avg"]
fitting.aparam_inv_std = data["@variables"]["aparam_inv_std"]
return fitting
@@ -922,7 +928,11 @@ def serialize(self, suffix: str = "") -> dict:
"nets": self.serialize_network(
ntypes=self.ntypes,
ndim=0 if self.mixed_types else 1,
in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
in_dim=(
self.dim_descrpt
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
),
neuron=self.n_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.resnet_dt,
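
In the TF build path above, the new `else` branch substitutes constant zero/one tensors for the removed graph variables, presumably to keep `t_aparam_avg` / `t_aparam_istd` defined when the aparam variables are not created; with a zero mean and unit inverse std the normalization would be an identity in any case. A NumPy stand-in illustrating that identity (not the TF graph code):

```python
import numpy as np

# With a zero mean and a unit inverse standard deviation,
# `(aparam - avg) * istd` leaves aparam unchanged.
aparam = np.array([[0.5, -1.0, 2.0]])
t_aparam_avg = np.zeros(3)   # stand-in for tf.zeros(self.numb_aparam, ...)
t_aparam_istd = np.ones(3)   # stand-in for tf.ones(self.numb_aparam, ...)

assert np.allclose((aparam - t_aparam_avg) * t_aparam_istd, aparam)
```
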
8 changes: 7 additions & 1 deletion source/tests/consistent/fitting/common.py
@@ -18,7 +18,7 @@
class FittingTest:
"""Useful utilities for descriptor tests."""

def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, aparam, suffix):
t_inputs = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_inputs")
t_natoms = tf.placeholder(tf.int32, natoms.shape, name="i_natoms")
t_atype = tf.placeholder(tf.int32, [None], name="i_atype")
@@ -30,6 +30,12 @@ def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
)
extras["fparam"] = t_fparam
feed_dict[t_fparam] = fparam
if aparam is not None:
t_aparam = tf.placeholder(
GLOBAL_TF_FLOAT_PRECISION, [None, None], name="i_aparam"
)
extras["aparam"] = t_aparam
feed_dict[t_aparam] = aparam
t_out = obj.build(
t_inputs,
t_natoms,
22 changes: 22 additions & 0 deletions source/tests/consistent/fitting/test_dos.py
@@ -58,6 +58,7 @@
("float64", "float32"), # precision
(True, False), # mixed_types
(0, 1), # numb_fparam
(0, 1), # numb_aparam
(10, 20), # numb_dos
)
class TestDOS(CommonTest, FittingTest, unittest.TestCase):
@@ -68,13 +69,15 @@ def data(self) -> dict:
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return {
"neuron": [5, 5, 5],
"resnet_dt": resnet_dt,
"precision": precision,
"numb_fparam": numb_fparam,
"numb_aparam": numb_aparam,
"seed": 20240217,
"numb_dos": numb_dos,
}
@@ -86,6 +89,7 @@ def skip_pt(self) -> bool:
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return CommonTest.skip_pt
@@ -115,6 +119,9 @@ def setUp(self):
# inconsistent if not sorted
self.atype.sort()
self.fparam = -np.ones((1,), dtype=GLOBAL_NP_FLOAT_PRECISION)
self.aparam = np.zeros_like(
self.atype, dtype=GLOBAL_NP_FLOAT_PRECISION
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
@@ -123,6 +130,7 @@ def addtional_data(self) -> dict:
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return {
@@ -137,6 +145,7 @@ def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return self.build_tf_fitting(
@@ -145,6 +154,7 @@ def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
self.natoms,
self.atype,
self.fparam if numb_fparam else None,
self.aparam if numb_aparam else None,
suffix,
)

@@ -154,6 +164,7 @@ def eval_pt(self, pt_obj: Any) -> Any:
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return (
@@ -163,6 +174,9 @@ def eval_pt(self, pt_obj: Any) -> Any:
fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE)
if numb_fparam
else None,
aparam=torch.from_numpy(self.aparam).to(device=PT_DEVICE)
if numb_aparam
else None,
)["dos"]
.detach()
.cpu()
@@ -175,12 +189,14 @@ def eval_dp(self, dp_obj: Any) -> Any:
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return dp_obj(
self.inputs,
self.atype.reshape(1, -1),
fparam=self.fparam if numb_fparam else None,
aparam=self.aparam if numb_aparam else None,
)["dos"]

def eval_jax(self, jax_obj: Any) -> Any:
@@ -189,13 +205,15 @@ def eval_jax(self, jax_obj: Any) -> Any:
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
fparam=jnp.asarray(self.fparam) if numb_fparam else None,
aparam=jnp.asarray(self.aparam) if numb_aparam else None,
)["dos"]
)

@@ -206,13 +224,15 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None,
)["dos"]
)

@@ -230,6 +250,7 @@ def rtol(self) -> float:
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float64":
@@ -247,6 +268,7 @@ def atol(self) -> float:
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float64":