From 05842da65cafd53318b693efc2c70be080ab8def Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 23 Jul 2024 23:35:02 +0800 Subject: [PATCH 1/3] fix(pt): fix `get_dim` for `DescrptDPA1Compat` --- deepmd/tf/descriptor/se_atten.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 13976a84e1..e255a0541c 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -2194,6 +2194,14 @@ def __init__( else: self.embd_input_dim = 1 + def get_dim_out(self) -> int: + """Returns the output dimension of this descriptor.""" + return ( + super().get_dim_out() + self.tebd_dim + if self.concat_output_tebd + else super().get_dim_out() + ) + def build( self, coord_: tf.Tensor, From f3f6b7e4b3d131f85894a2a91afc483f63de3651 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 19 Jul 2024 04:46:29 -0400 Subject: [PATCH 2/3] test get_dim_out Signed-off-by: Jinzhe Zeng --- source/tests/consistent/descriptor/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 13ceef84ab..74fc3d9b07 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -49,6 +49,8 @@ def build_tf_descriptor(self, obj, natoms, coords, atype, box, suffix): {}, suffix=suffix, ) + # ensure get_dim_out gives the correct shape + t_des = tf.reshape(t_des, [1, natoms[0], obj.get_dim_out()]) return [t_des], { t_coord: coords, t_type: atype, From 98c2b2b106949a0e5014529a8586b74932345512 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 25 Jul 2024 23:33:48 +0800 Subject: [PATCH 3/3] Update se_atten.py --- deepmd/tf/descriptor/se_atten.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 7d61ca1581..37bcd7eea0 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -765,7 +765,14 @@ def _pass_filter( type_embedding=type_embedding, atype=atype, ) - layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0], self.get_dim_out()]) + layer = tf.reshape( + layer, + [ + tf.shape(inputs)[0], + natoms[0], + self.filter_neuron[-1] * self.n_axis_neuron, + ], + ) qmat = tf.reshape( qmat, [tf.shape(inputs)[0], natoms[0], self.get_dim_rot_mat_1() * 3] )