From e9451c303b68a63102a6598c30e472ac4adb445a Mon Sep 17 00:00:00 2001 From: Katze2664 <40237250+Katze2664@users.noreply.github.com> Date: Tue, 28 May 2024 18:46:53 +1000 Subject: [PATCH] Update linear.py Simpler prediction using np.matmul --- mushroom_rl/approximators/parametric/linear.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mushroom_rl/approximators/parametric/linear.py b/mushroom_rl/approximators/parametric/linear.py index 55c4b0fc..e93df07e 100644 --- a/mushroom_rl/approximators/parametric/linear.py +++ b/mushroom_rl/approximators/parametric/linear.py @@ -67,9 +67,7 @@ def predict(self, x, **predict_params): """ phi = np.atleast_2d(self.phi(x)) - prediction = np.ones((phi.shape[0], self._w.shape[0])) - for i, phi_i in enumerate(phi): - prediction[i] = phi_i.dot(self._w.T) + prediction = np.matmul(phi, self._w.T) return prediction