From 7be48e87861b5b46050424efca184593fcfb702f Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Fri, 1 Sep 2023 09:38:30 -0400
Subject: [PATCH] [REF] Make progressbar description agnostic to operation

---
 curvlinops/_base.py    | 19 ++++++++++++++-----
 curvlinops/jacobian.py | 10 ++++++----
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/curvlinops/_base.py b/curvlinops/_base.py
index e655eb9..570a3ea 100644
--- a/curvlinops/_base.py
+++ b/curvlinops/_base.py
@@ -79,7 +79,9 @@ def __init__(
         self._device = self._infer_device(self._params)
         self._progressbar = progressbar
 
-        self._N_data = sum(X.shape[0] for (X, _) in self._loop_over_data())
+        self._N_data = sum(
+            X.shape[0] for (X, _) in self._loop_over_data(desc="_N_data")
+        )
 
         if check_deterministic:
             old_device = self._device
@@ -206,7 +208,7 @@ def _matvec(self, x: ndarray) -> ndarray:
         x_list = self._preprocess(x)
         out_list = [zeros_like(x) for x in x_list]
 
-        for X, y in self._loop_over_data():
+        for X, y in self._loop_over_data(desc="_matvec"):
             normalization_factor = self._get_normalization_factor(X, y)
 
             for mat_x, current in zip(out_list, self._matvec_batch(X, y, x_list)):
@@ -266,16 +268,23 @@ def _postprocess(self, x_list: List[Tensor]) -> ndarray:
         """
         return self.flatten_and_concatenate(x_list).cpu().numpy()
 
-    def _loop_over_data(self) -> Iterable[Tuple[Tensor, Tensor]]:
+    def _loop_over_data(
+        self, desc: Optional[str] = None
+    ) -> Iterable[Tuple[Tensor, Tensor]]:
         """Yield batches of the data set, loaded to the correct device.
 
+        Args:
+            desc: Description for the progress bar. Will be ignored if progressbar is
+                disabled.
+
         Yields:
             Mini-batches ``(X, y)``.
         """
         data_iter = iter(self._data)
 
         if self._progressbar:
-            data_iter = tqdm(data_iter, desc="matvec")
+            desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}"
+            data_iter = tqdm(data_iter, desc=desc)
 
         for X, y in data_iter:
             X, y = X.to(self._device), y.to(self._device)
@@ -298,7 +307,7 @@ def gradient_and_loss(self) -> Tuple[List[Tensor], Tensor]:
         total_loss = tensor([0.0], device=self._device)
         total_grad = [zeros_like(p) for p in self._params]
 
-        for X, y in self._loop_over_data():
+        for X, y in self._loop_over_data(desc="gradient_and_loss"):
             loss = self._loss_func(self._model_func(X), y)
             normalization_factor = self._get_normalization_factor(X, y)
 
diff --git a/curvlinops/jacobian.py b/curvlinops/jacobian.py
index 96dfd59..b976644 100644
--- a/curvlinops/jacobian.py
+++ b/curvlinops/jacobian.py
@@ -88,7 +88,8 @@ def _check_deterministic(self):
 
         with no_grad():
             for (X1, y1), (X2, y2) in zip(
-                self._loop_over_data(), self._loop_over_data()
+                self._loop_over_data(desc="_check_deterministic_data_pred1"),
+                self._loop_over_data(desc="_check_deterministic_data_pred2"),
             ):
                 pred1, y1 = self._model_func(X1).cpu().numpy(), y1.cpu().numpy()
                 pred2, y2 = self._model_func(X2).cpu().numpy(), y2.cpu().numpy()
@@ -117,7 +118,7 @@ def _matvec(self, x: ndarray) -> ndarray:
             jvp(self._model_func(X), self._params, x_list, retain_graph=False)[
                 0
             ].flatten(start_dim=1)
-            for X, _ in self._loop_over_data()
+            for X, _ in self._loop_over_data(desc="_matvec")
         ]
 
         return self._postprocess(out_list)
@@ -212,7 +213,8 @@ def _check_deterministic(self):
 
         with no_grad():
             for (X1, y1), (X2, y2) in zip(
-                self._loop_over_data(), self._loop_over_data()
+                self._loop_over_data(desc="_check_deterministic_data_pred1"),
+                self._loop_over_data(desc="_check_deterministic_data_pred2"),
             ):
                 pred1, y1 = self._model_func(X1).cpu().numpy(), y1.cpu().numpy()
                 pred2, y2 = self._model_func(X2).cpu().numpy(), y2.cpu().numpy()
@@ -240,7 +242,7 @@ def _matvec(self, x: ndarray) -> ndarray:
         out_list = [zeros_like(p) for p in self._params]
 
         processed = 0
-        for X, _ in self._loop_over_data():
+        for X, _ in self._loop_over_data(desc="_matvec"):
             pred = self._model_func(X)
             v = x_torch[processed : processed + pred.numel()].reshape_as(pred)
             processed += pred.numel()
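
Note: to illustrate the new progress-bar labeling outside the library, below is a
minimal, self-contained sketch of the description logic. DemoOperator and its
integer "batches" are hypothetical stand-ins, not part of curvlinops; only the
f-string and the tqdm call mirror the patch.

from typing import Iterable, Iterator, Optional, Tuple

from tqdm import tqdm


class DemoOperator:
    """Toy stand-in for a curvlinops linear operator (hypothetical class).

    Reproduces only the progress-bar description logic from this patch.
    """

    def __init__(self, data: Iterable[Tuple[int, int]], progressbar: bool = True):
        self._data = data
        self._progressbar = progressbar

    def _loop_over_data(
        self, desc: Optional[str] = None
    ) -> Iterator[Tuple[int, int]]:
        data_iter = iter(self._data)
        if self._progressbar:
            # Prefix the class name, then append the operation-specific suffix,
            # e.g. 'DemoOperator._matvec' instead of the old hard-coded 'matvec'.
            desc = f"{self.__class__.__name__}{'' if desc is None else f'.{desc}'}"
            data_iter = tqdm(data_iter, desc=desc)
        yield from data_iter


if __name__ == "__main__":
    op = DemoOperator(data=[(1, 2), (3, 4)])
    # Progress bar reads 'DemoOperator._matvec' while iterating.
    for X, y in op._loop_over_data(desc="_matvec"):
        pass
    # Without a description, the bar falls back to just 'DemoOperator'.
    for X, y in op._loop_over_data():
        pass

With this scheme, every operation that iterates over the data identifies both the
operator class and the calling method in its progress bar, which is what makes the
description agnostic to any single operation such as the old "matvec" label.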