Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 27, 2024
1 parent 5d8c96a commit feda81b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def forward(
info_tensor = torch.tensor(info, dtype=self.prec, device="cpu")
gg_t = gg_t.reshape(-1, gg_t.size(-1))
# Convert all tensors to the required precision at once
ss, rr, gg_t = [t.to(self.prec) for t in (ss, rr, gg_t)]
ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t))
xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
tensor_data.contiguous(),
info_tensor.contiguous(),
Expand Down
13 changes: 8 additions & 5 deletions deepmd/pt/utils/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,10 +557,13 @@ def _layer_1(self, x, w, b):
# Change the embedding net range to sw / min_nbor_dist
def _get_env_mat_range(self, min_nbor_dist):
sw = self._spline5_switch(min_nbor_dist, self.rcut_smth, self.rcut)
if isinstance(self.descrpt, (
deepmd.pt.model.descriptor.DescrptSeA,
deepmd.pt.model.descriptor.DescrptDPA1,
)):
if isinstance(
self.descrpt,
(
deepmd.pt.model.descriptor.DescrptSeA,
deepmd.pt.model.descriptor.DescrptDPA1,
),
):
lower = -self.davg[:, 0] / self.dstd[:, 0]
upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0]
elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeT):
Expand Down Expand Up @@ -828,7 +831,7 @@ def grad(xbar, y, functype): # functype=tanh, gelu, ..
return 1.0 - 1.0 / (1.0 + np.exp(xbar))
elif functype == 6:
return y * (1 - y)

Check warning on line 833 in deepmd/pt/utils/tabulate.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/tabulate.py#L826-L833

Added lines #L826 - L833 were not covered by tests

raise ValueError(f"Unsupported function type: {functype}")

Check warning on line 835 in deepmd/pt/utils/tabulate.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/tabulate.py#L835

Added line #L835 was not covered by tests


Expand Down

0 comments on commit feda81b

Please sign in to comment.