diff --git a/api/tests/eigh.py b/api/tests/eigh.py index 186be5734b..611eb5a3df 100644 --- a/api/tests/eigh.py +++ b/api/tests/eigh.py @@ -31,7 +31,7 @@ def build_graph(self, config): self.feed_list = [x] self.fetch_list = [out_w, out_v] if config.backward: - self.append_gradients(out_w, [x]) + self.append_gradients(out_w.sum() + paddle.abs(out_v).sum(), [x]) @benchmark_registry.register("eigh") @@ -43,4 +43,4 @@ def build_graph(self, config): self.feed_list = [x] self.fetch_list = [out_w, out_v] if config.backward: - self.append_gradients(out_w, [x]) + self.append_gradients(out_w.sum() + torch.abs(out_v).sum(), [x])