diff --git a/jpc/_test.py b/jpc/_test.py index 83c1096..54a295d 100644 --- a/jpc/_test.py +++ b/jpc/_test.py @@ -128,7 +128,7 @@ def test_hpc( atol=1e-3 ), dt: Union[float, int] = None -): +) -> Tuple[Scalar, Scalar, Scalar, Array]: """Computes test metrics for hybrid predictive coding. Calculates input accuracy of (i) amortiser, (ii) generative, and (ii) @@ -171,7 +171,7 @@ def test_hpc( amort_preds = amort_activities[0] hpc_preds = solve_pc_activities( network=generator, - activities=amort_activities[1:], + activities=amort_activities, output=output, solver=solver, n_iters=n_iters,