Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix Strategy class to ensure consistent tensor operations for data normalization #403

Closed
53 changes: 35 additions & 18 deletions aepsych/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def __init__(
lb=self.lb, ub=self.ub, size=self._n_eval_points
)

self.x = None
self.y = None
self.x: Optional[torch.Tensor] = None
self.y: Optional[torch.Tensor] = None
self.n = 0
self.min_asks = min_asks
self._count = 0
Expand Down Expand Up @@ -170,38 +170,41 @@ def __init__(

self.name = name

def normalize_inputs(self, x, y):
def normalize_inputs(self, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""converts inputs into normalized format for this strategy

Args:
x (np.ndarray): training inputs
y (np.ndarray): training outputs
x (torch.Tensor): training inputs
yalsaffar marked this conversation as resolved.
Show resolved Hide resolved
y (torch.Tensor): training outputs

Returns:
x (np.ndarray): training inputs, normalized
y (np.ndarray): training outputs, normalized
x (torch.Tensor): training inputs, normalized
y (torch.Tensor): training outputs, normalized
n (int): number of observations
"""
assert (
x.shape == self.event_shape or x.shape[1:] == self.event_shape
), f"x shape should be {self.event_shape} or batch x {self.event_shape}, instead got {x.shape}"


# Handle scalar y values
if y.ndim == 0:
y = y.unsqueeze(0)

if x.shape == self.event_shape:
x = x[None, :]

if self.x is None:
x = np.r_[x]
else:
x = np.r_[self.x, x]
if self.x is not None:
x = torch.cat((self.x, x), dim=0)

if self.y is None:
y = np.r_[y]
else:
y = np.r_[self.y, y]
if self.y is not None:
y = torch.cat((self.y, y), dim=0)

# Ensure the correct dtype
x = x.to(torch.float64)
yalsaffar marked this conversation as resolved.
Show resolved Hide resolved
y = y.to(torch.float64)
n = y.shape[0]

return torch.Tensor(x), torch.Tensor(y), n
return x, y, n

# TODO: allow user to pass in generator options
@ensure_model_is_fresh
Expand Down Expand Up @@ -306,7 +309,21 @@ def n_trials(self):
)
return self.min_asks

def add_data(self, x, y):
def add_data(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]):
"""
JasonKChow marked this conversation as resolved.
Show resolved Hide resolved
Adds new data points to the strategy, and normalizes the inputs.

Args:
x (torch.Tensor, np.ndarray): The input data points. Can be a PyTorch tensor or NumPy array.
y (torch.Tensor, np.ndarray): The output data points. Can be a PyTorch tensor or NumPy array.

"""
# Necessary as sometimes the data is passed in as numpy arrays or torch tensors.
if not isinstance(y, torch.Tensor):
y = torch.tensor(y, dtype=torch.float64)
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, dtype=torch.float64)

self.x, self.y, self.n = self.normalize_inputs(x, y)
self._model_is_fresh = False

Expand Down
Loading