From 09b1d59a1f66703ea954db3fae7031d24fd3d2f7 Mon Sep 17 00:00:00 2001 From: Yousif Alsaffar Date: Thu, 17 Oct 2024 07:43:03 -0700 Subject: [PATCH] Fix Strategy class to ensure consistent tensor operations for data normalization (#403) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: This PR addresses the second part of issue https://github.com/facebookresearch/aepsych/issues/365, focusing on the `Strategy` class and how data is added and normalized, transitioning the process to use tensors instead of NumPy operations. The changes were made specifically within the `normalize_inputs` method of the `Strategy` class. Previously, this method had mismatched docstrings indicating `np.array` usage. Now, it consistently accepts and returns tensors, performing all operations within tensors. The `normalize_inputs` method is called in `add_data()` (where the confusion arises), as the data passed can vary (either tensors or `np.array`). To resolve this, the method now acts as the first step, accepting both formats and then converting everything to tensors for consistent operations (model fitting later on). It’s also crucial to ensure the data type is `float64`, as `gpytorch` does not support other data types. Additionally, a detailed docstring was added to clarify the method's expectations and ensure its proper use going forward. Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/403 Reviewed By: crasanders Differential Revision: D64343236 Pulled By: JasonKChow fbshipit-source-id: 413077605f4fa46b82405897c713cbc62b58a3f3 --- aepsych/strategy.py | 53 ++++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/aepsych/strategy.py b/aepsych/strategy.py index 704dd09fd..9ed40e196 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -139,8 +139,8 @@ def __init__( lb=self.lb, ub=self.ub, size=self._n_eval_points ) - self.x = None - self.y = None + self.x: Optional[torch.Tensor] = None + self.y: Optional[torch.Tensor] = None self.n = 0 self.min_asks = min_asks self._count = 0 @@ -170,38 +170,41 @@ def __init__( self.name = name - def normalize_inputs(self, x, y): + def normalize_inputs(self, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: """converts inputs into normalized format for this strategy Args: - x (np.ndarray): training inputs - y (np.ndarray): training outputs + x (torch.Tensor): training inputs + y (torch.Tensor): training outputs Returns: - x (np.ndarray): training inputs, normalized - y (np.ndarray): training outputs, normalized + x (torch.Tensor): training inputs, normalized + y (torch.Tensor): training outputs, normalized n (int): number of observations """ assert ( x.shape == self.event_shape or x.shape[1:] == self.event_shape ), f"x shape should be {self.event_shape} or batch x {self.event_shape}, instead got {x.shape}" - + + # Handle scalar y values + if y.ndim == 0: + y = y.unsqueeze(0) + if x.shape == self.event_shape: x = x[None, :] - if self.x is None: - x = np.r_[x] - else: - x = np.r_[self.x, x] + if self.x is not None: + x = torch.cat((self.x, x), dim=0) - if self.y is None: - y = np.r_[y] - else: - y = np.r_[self.y, y] + if self.y is not None: + y = torch.cat((self.y, y), dim=0) + # Ensure the correct dtype + x = x.to(torch.float64) + y = y.to(torch.float64) n = y.shape[0] - return torch.Tensor(x), torch.Tensor(y), n + return x, y, n # TODO: allow user to pass in generator options @ensure_model_is_fresh @@ -306,7 +309,21 @@ def n_trials(self): ) return self.min_asks - def add_data(self, x, y): + def add_data(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]): + """ + Adds new data points to the strategy, and normalizes the inputs. + + Args: + x (torch.Tensor, np.ndarray): The input data points. Can be a PyTorch tensor or NumPy array. + y (torch.Tensor, np.ndarray): The output data points. Can be a PyTorch tensor or NumPy array. + + """ + # Necessary as sometimes the data is passed in as numpy arrays or torch tensors. + if not isinstance(y, torch.Tensor): + y = torch.tensor(y, dtype=torch.float64) + if not isinstance(x, torch.Tensor): + x = torch.tensor(x, dtype=torch.float64) + self.x, self.y, self.n = self.normalize_inputs(x, y) self._model_is_fresh = False