diff --git a/aepsych/strategy.py b/aepsych/strategy.py index 704dd09fd..9ed40e196 100644 --- a/aepsych/strategy.py +++ b/aepsych/strategy.py @@ -139,8 +139,8 @@ def __init__( lb=self.lb, ub=self.ub, size=self._n_eval_points ) - self.x = None - self.y = None + self.x: Optional[torch.Tensor] = None + self.y: Optional[torch.Tensor] = None self.n = 0 self.min_asks = min_asks self._count = 0 @@ -170,38 +170,41 @@ def __init__( self.name = name - def normalize_inputs(self, x, y): + def normalize_inputs(self, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: """converts inputs into normalized format for this strategy Args: - x (np.ndarray): training inputs - y (np.ndarray): training outputs + x (torch.Tensor): training inputs + y (torch.Tensor): training outputs Returns: - x (np.ndarray): training inputs, normalized - y (np.ndarray): training outputs, normalized + x (torch.Tensor): training inputs, normalized + y (torch.Tensor): training outputs, normalized n (int): number of observations """ assert ( x.shape == self.event_shape or x.shape[1:] == self.event_shape ), f"x shape should be {self.event_shape} or batch x {self.event_shape}, instead got {x.shape}" - + + # Handle scalar y values + if y.ndim == 0: + y = y.unsqueeze(0) + if x.shape == self.event_shape: x = x[None, :] - if self.x is None: - x = np.r_[x] - else: - x = np.r_[self.x, x] + if self.x is not None: + x = torch.cat((self.x, x), dim=0) - if self.y is None: - y = np.r_[y] - else: - y = np.r_[self.y, y] + if self.y is not None: + y = torch.cat((self.y, y), dim=0) + # Ensure the correct dtype + x = x.to(torch.float64) + y = y.to(torch.float64) n = y.shape[0] - return torch.Tensor(x), torch.Tensor(y), n + return x, y, n # TODO: allow user to pass in generator options @ensure_model_is_fresh @@ -306,7 +309,21 @@ def n_trials(self): ) return self.min_asks - def add_data(self, x, y): + def add_data(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]): + """ + Adds new data points to the strategy, and normalizes the inputs. + + Args: + x (torch.Tensor, np.ndarray): The input data points. Can be a PyTorch tensor or NumPy array. + y (torch.Tensor, np.ndarray): The output data points. Can be a PyTorch tensor or NumPy array. + + """ + # Necessary as sometimes the data is passed in as numpy arrays or torch tensors. + if not isinstance(y, torch.Tensor): + y = torch.tensor(y, dtype=torch.float64) + if not isinstance(x, torch.Tensor): + x = torch.tensor(x, dtype=torch.float64) + self.x, self.y, self.n = self.normalize_inputs(x, y) self._model_is_fresh = False