diff --git a/jpc/_core/_grads.py b/jpc/_core/_grads.py index 893af70..7c77d1d 100644 --- a/jpc/_core/_grads.py +++ b/jpc/_core/_grads.py @@ -39,7 +39,7 @@ def _neg_activity_grad( """ generator, output, input = args - dFdzs = grad(pc_energy_fn, argnums=1)( + dFdzs = grad(energy_fn, argnums=1)( generator, activities, output,