Skip to content

Commit

Permalink
adding type hints and adjusting normalize_inputs conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
yalsaffar committed Oct 14, 2024
1 parent bdaa99e commit 01cfbb5
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions aepsych/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(

self.name = name

def normalize_inputs(self, x, y):
def normalize_inputs(self, x:torch.Tensor, y:torch.Tensor):
"""converts inputs into normalized format for this strategy
Args:
Expand All @@ -193,14 +193,10 @@ def normalize_inputs(self, x, y):
if x.shape == self.event_shape:
x = x[None, :]

if self.x is None:
x = x
else:
if self.x is not None:
x = torch.cat((self.x, x), dim=0)

if self.y is None:
y = y
else:
if self.y is not None:
y = torch.cat((self.y, y), dim=0)

# Ensure the correct dtype
Expand Down Expand Up @@ -313,7 +309,7 @@ def n_trials(self):
)
return self.min_asks

def add_data(self, x, y):
def add_data(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]):
"""
Adds new data points to the strategy, and normalizes the inputs.
Expand Down

0 comments on commit 01cfbb5

Please sign in to comment.