
🐛 use perturb_continuous_data_extended from perturbations
(as defined on developer branch)
- remove debug log statements here
- the version in "identify_associations.py" had a hardcoded number of features (5): num_features = 5 (see the sketch below)
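For context, the two versions differ mainly in how the number of features to perturb is chosen. A minimal sketch of that difference, using the names that appear in the diffs below; the dataset names and widths are invented for illustration:

# Toy setup: three continuous datasets, 4, 3 and 6 features wide (invented numbers).
con_dataset_names = ["proteins", "metabolites", "lipids"]
con_shapes = [4, 3, 6]
target_dataset_name = "metabolites"

target_idx = con_dataset_names.index(target_dataset_name)

# Kept version (perturbations.py): perturb every feature of the target dataset.
num_features = con_shapes[target_idx]  # -> 3

# Removed version (identify_associations.py): a fixed count, regardless of the
# actual width of the target dataset.
num_features = 5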
Henry committed Jul 3, 2024
1 parent 5aa03ff commit 8b06298
Showing 2 changed files with 4 additions and 109 deletions.
16 changes: 2 additions & 14 deletions src/move/data/perturbations.py
@@ -371,13 +371,11 @@ def perturb_continuous_data_extended(
         datasets. Scaling is done per dataset, not per feature -> slightly different
         stds feature to feature.
     """
-    logger = get_logger(__name__)
-    logger.debug("Inside perturb_extended, creating baseline dataset")
+
     baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
     assert baseline_dataset.con_shapes is not None
     assert baseline_dataset.con_all is not None
 
-    logger.debug("Creating target_ics, splits, and slice")
     target_idx = con_dataset_names.index(target_dataset_name)  # dataset index
     splits = np.cumsum([0] + baseline_dataset.con_shapes)
     slice_ = slice(*splits[target_idx : target_idx + 2])
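As an aside on the hunk above: splits converts the per-dataset widths in con_shapes into cumulative column offsets, and slice_ then picks the target dataset's columns out of the concatenated continuous matrix. A small self-contained illustration with invented shapes:

import numpy as np

con_shapes = [4, 3, 6]   # widths of the concatenated continuous datasets (invented)
target_idx = 1           # perturb the second dataset

splits = np.cumsum([0] + con_shapes)                   # array([ 0,  4,  7, 13])
slice_ = slice(*splits[target_idx : target_idx + 2])   # slice(4, 7)

X = np.arange(2 * 13).reshape(2, 13)   # stand-in for baseline_dataset.con_all
print(X[:, slice_].shape)              # (2, 3): only the target dataset's columns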
@@ -387,7 +385,6 @@
     perturbations_list = []
 
     for i in range(num_features):
-        logger.debug(f"Getting perturbed dataset for feature {i}")
         perturbed_con = baseline_dataset.con_all.clone()
         target_dataset = perturbed_con[:, slice_]
         # Change the desired feature value by:
@@ -403,7 +400,7 @@
         elif perturbation_type == "minus_std":
             target_dataset[:, i] -= torch.FloatTensor([std_feat_val_list[i]])
 
-        # perturbations_list.append(target_dataset[:, i].numpy())
+        perturbations_list.append(target_dataset[:, i].numpy())
 
         perturbed_dataset = MOVEDataset(
             baseline_dataset.cat_all,
@@ -418,15 +415,6 @@
             batch_size=baseline_dataloader.batch_size,
         )
         dataloaders.append(perturbed_dataloader)
-    logger.debug("Finished perturb_continuous_data_extended function")
-
-    # Plot the perturbations for all features, collapsed in one plot:
-    # if output_subpath is not None:
-    #     fig = plot_value_distributions(np.array(perturbations_list).transpose())
-    #     fig_path = str(
-    #         output_subpath / f"perturbation_distribution_{target_dataset_name}.png"
-    #     )
-    #     fig.savefig(fig_path)
 
     # Plot the perturbations for all features, collapsed in one plot:
     if output_subpath is not None:
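To make the four perturbation modes in the hunks above concrete, here is a small torch sketch of what happens to a single feature column. feature_stats is stubbed with plain torch reductions, so this illustrates the idea rather than reproducing the library function:

import torch

# Toy target dataset: 4 samples x 3 features (invented values).
target_dataset = torch.tensor(
    [[0.0, 1.0, 2.0],
     [1.0, 2.0, 3.0],
     [2.0, 3.0, 4.0],
     [3.0, 4.0, 5.0]]
)
i = 0  # index of the feature being perturbed

# Stand-ins for the values feature_stats would provide for feature i.
min_val = target_dataset[:, i].min()
max_val = target_dataset[:, i].max()
std_val = target_dataset[:, i].std()

perturbed = {kind: target_dataset.clone()
             for kind in ("minimum", "maximum", "plus_std", "minus_std")}
perturbed["minimum"][:, i] = min_val     # every sample set to the feature minimum
perturbed["maximum"][:, i] = max_val     # every sample set to the feature maximum
perturbed["plus_std"][:, i] += std_val   # one standard deviation added
perturbed["minus_std"][:, i] -= std_val  # one standard deviation subtracted

In the function itself, each such perturbed matrix is wrapped in a MOVEDataset and a DataLoader, so the caller receives one dataloader per feature of the target dataset.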
97 changes: 2 additions & 95 deletions src/move/tasks/identify_associations.py
@@ -25,9 +25,10 @@
 from move.core.typing import BoolArray, FloatArray, IntArray
 from move.data import io
 from move.data.dataloaders import MOVEDataset, make_dataloader
-from move.data.perturbations import (  # perturb_continuous_data_extended,
+from move.data.perturbations import (
     ContinuousPerturbationType,
     perturb_categorical_data,
+    perturb_continuous_data_extended,
 )
 from move.data.preprocessing import feature_stats, one_hot_encode_single
 from move.models.vae import VAE
@@ -42,100 +43,6 @@
 CONTINUOUS_TARGET_VALUE = ["minimum", "maximum", "plus_std", "minus_std"]
 
 
-def perturb_continuous_data_extended(
-    baseline_dataloader: DataLoader,
-    con_dataset_names: list[str],
-    target_dataset_name: str,
-    perturbation_type: ContinuousPerturbationType,
-    output_subpath: Optional[Path] = None,
-) -> list[DataLoader]:
-    logger = get_logger(__name__)
-
-    """Add perturbations to continuous data. For each feature in the target
-    dataset, change the feature's value in all samples (in rows):
-    1,2) substituting this feature in all samples by the feature's minimum/maximum value
-    3,4) Adding/Substracting one standard deviation to the sample's feature value.
-    Args:
-        baseline_dataloader: Baseline dataloader
-        con_dataset_names: List of continuous dataset names
-        target_dataset_name: Target continuous dataset to perturb
-        perturbation_type: 'minimum', 'maximum', 'plus_std' or 'minus_std'.
-        output_subpath: path where the figure showing the perturbation will be saved
-    Returns:
-        - List of dataloaders containing all perturbed datasets
-        - Plot of the feature value distribution after the perturbation. Note that
-          all perturbations are collapsed into one single plot.
-    Note:
-        This function was created so that it could generalize to non-normalized
-        datasets. Scaling is done per dataset, not per feature -> slightly different
-        stds feature to feature.
-    """
-    logger.debug("Inside perturb_extended, creating baseline dataset")
-    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
-    assert baseline_dataset.con_shapes is not None
-    assert baseline_dataset.con_all is not None
-
-    logger.debug("Creating target_ics, splits, and slice")
-    target_idx = con_dataset_names.index(target_dataset_name)  # dataset index
-    splits = np.cumsum([0] + baseline_dataset.con_shapes)
-    slice_ = slice(*splits[target_idx : target_idx + 2])
-
-    # num_features = baseline_dataset.con_shapes[target_idx]
-    # CHANGED THIS TO TRY IT. CHANGE LATER
-    num_features = 5
-    logger.debug(f"number of feature to perturb is {num_features}")
-    dataloaders = []
-    # perturbations_list = []
-    # Change below.
-    # num_features = 10
-
-    for i in range(num_features):
-        logger.debug(f"Getting perturbed dataset for feature {i}")
-        perturbed_con = baseline_dataset.con_all.clone()
-        target_dataset = perturbed_con[:, slice_]
-        # Change the desired feature value by:
-        min_feat_val_list, max_feat_val_list, std_feat_val_list = feature_stats(
-            target_dataset
-        )
-        if perturbation_type == "minimum":
-            target_dataset[:, i] = torch.FloatTensor([min_feat_val_list[i]])
-        elif perturbation_type == "maximum":
-            target_dataset[:, i] = torch.FloatTensor([max_feat_val_list[i]])
-        elif perturbation_type == "plus_std":
-            target_dataset[:, i] += torch.FloatTensor([std_feat_val_list[i]])
-        elif perturbation_type == "minus_std":
-            target_dataset[:, i] -= torch.FloatTensor([std_feat_val_list[i]])
-
-        # perturbations_list.append(target_dataset[:, i].numpy())
-
-        perturbed_dataset = MOVEDataset(
-            baseline_dataset.cat_all,
-            perturbed_con,
-            baseline_dataset.cat_shapes,
-            baseline_dataset.con_shapes,
-        )
-
-        perturbed_dataloader = DataLoader(
-            perturbed_dataset,
-            shuffle=False,
-            batch_size=baseline_dataloader.batch_size,
-        )
-        dataloaders.append(perturbed_dataloader)
-    logger.debug("Finished perturb_continuous_data_extended function")
-
-    # Plot the perturbations for all features, collapsed in one plot:
-    # if output_subpath is not None:
-    #     fig = plot_value_distributions(np.array(perturbations_list).transpose())
-    #     fig_path = str(
-    #         output_subpath / f"perturbation_distribution_{target_dataset_name}.png"
-    #     )
-    #     fig.savefig(fig_path)
-
-    return dataloaders
-
-
 def _get_task_type(
     task_config: IdentifyAssociationsConfig,
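The call site itself is not part of this diff, but with the import restored, identify_associations.py can use the shared implementation directly. A rough sketch of such a call; every variable below is a placeholder for objects built earlier in the task, not something taken from this commit:

from move.data.perturbations import perturb_continuous_data_extended

dataloaders = perturb_continuous_data_extended(
    baseline_dataloader,    # DataLoader over the unperturbed MOVEDataset
    con_dataset_names,      # names of all continuous datasets, in concatenation order
    target_dataset_name,    # the continuous dataset whose features are perturbed
    "plus_std",             # one of CONTINUOUS_TARGET_VALUE
    output_subpath,         # where the perturbation-distribution plot is written
)
# One perturbed dataloader per feature in the target dataset.
print(len(dataloaders))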
