diff --git a/mushroom_rl/approximators/parametric/linear.py b/mushroom_rl/approximators/parametric/linear.py index 55c4b0fc..e93df07e 100644 --- a/mushroom_rl/approximators/parametric/linear.py +++ b/mushroom_rl/approximators/parametric/linear.py @@ -67,9 +67,7 @@ def predict(self, x, **predict_params): """ phi = np.atleast_2d(self.phi(x)) - prediction = np.ones((phi.shape[0], self._w.shape[0])) - for i, phi_i in enumerate(phi): - prediction[i] = phi_i.dot(self._w.T) + prediction = np.matmul(phi, self._w.T) return prediction