Commit 8277891

🔥 remove duplicated functionality
- use only one version in extra module
Henry committed Jul 4, 2024
1 parent 6cfd1f8 commit 8277891
Showing 2 changed files with 5 additions and 119 deletions.
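In short: `analyze_latent.py` previously carried its own copies of `perturb_continuous_data_one` and `perturb_categorical_data_one`; this commit deletes those copies and keeps the single implementations in `move.data.perturbations`, which `analyze_latent.py` now imports, as the diff below shows:

```python
# After this commit, analyze_latent.py uses the shared implementations:
from move.data.perturbations import (
    perturb_categorical_data_one,
    perturb_continuous_data_one,
)
```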
10 changes: 1 addition & 9 deletions src/move/data/perturbations.py
@@ -20,7 +20,6 @@
 ContinuousPerturbationType = Literal["minimum", "maximum", "plus_std", "minus_std"]
 
 
-# Also in analyze_latent.py
 def perturb_continuous_data_one(
     baseline_dataloader: DataLoader,
     con_dataset_names: list[str],
@@ -35,7 +34,7 @@ def perturb_continuous_data_one(
         baseline_dataloader: Baseline dataloader
         con_dataset_names: List of continuous dataset names
         target_dataset_name: Target continuous dataset to perturb
-        target_value: Target value. In analyze_latent, it will be 0
+        target_value: Target value.
 
     Returns:
         One dataloader, with the ith dataset perturbed
@@ -49,11 +48,7 @@ def perturb_continuous_data_one(
     splits = np.cumsum([0] + baseline_dataset.con_shapes)
     slice_ = slice(*splits[target_idx : target_idx + 2])
 
-    # num_features = baseline_dataset.con_shapes[target_idx]
-    # dataloaders = []
     i = index_pert_feat
-    # Instead of the loop, we do it only for one
-    # for i in range(num_features):
     perturbed_con = baseline_dataset.con_all.clone()
     target_dataset = perturbed_con[:, slice_]
     target_dataset[:, i] = torch.FloatTensor([target_value])
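As an aside, the `splits`/`slice_` idiom above picks out the target dataset's columns from the concatenated continuous matrix. A tiny self-contained illustration (the shapes are made up):

```python
import numpy as np

# Toy version of the slicing logic above; shapes are hypothetical.
con_shapes = [3, 5, 2]  # feature counts of three continuous datasets
splits = np.cumsum([0] + con_shapes)  # array([ 0,  3,  8, 10])
target_idx = 1  # suppose the second dataset is the target
slice_ = slice(*splits[target_idx : target_idx + 2])  # columns 3..7
assert (splits[target_idx], splits[target_idx + 1]) == (3, 8)
```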
@@ -103,11 +98,8 @@ def perturb_categorical_data_one(
     slice_ = slice(*splits[target_idx : target_idx + 2])
 
     target_shape = baseline_dataset.cat_shapes[target_idx]
-    # num_features = target_shape[0]
 
     i = index_pert_feat
-    # dataloaders = []
-    # for i in range(num_features):
     perturbed_cat = baseline_dataset.cat_all.clone()
     target_dataset = perturbed_cat[:, slice_].view(
         baseline_dataset.num_samples, *target_shape
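For context, a minimal sketch of how the retained continuous helper might be called; the dataset names and feature index are hypothetical, only the signature and the zero target value (noted in the old docstring) come from the diff above:

```python
from torch.utils.data import DataLoader

from move.data.perturbations import perturb_continuous_data_one


def perturb_one_feature(baseline_dataloader: DataLoader) -> DataLoader:
    """Illustrative wrapper: set one feature of a hypothetical 'proteomics'
    dataset to 0, the value analyze_latent used to pass."""
    return perturb_continuous_data_one(
        baseline_dataloader,
        con_dataset_names=["proteomics", "metagenomics"],  # hypothetical names
        target_dataset_name="proteomics",
        target_value=0.0,
        index_pert_feat=0,  # one feature only; the old per-feature loop is gone
    )
```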
114 changes: 4 additions & 110 deletions src/move/tasks/analyze_latent.py
@@ -12,7 +12,6 @@
 import pandas as pd
 import torch
 from sklearn.base import TransformerMixin
-from torch.utils.data import DataLoader
 
 import move.visualization as viz
 from move.analysis.metrics import (
@@ -24,119 +23,14 @@
 from move.core.typing import FloatArray
 from move.data import io
 from move.data.dataloaders import MOVEDataset, make_dataloader
+from move.data.perturbations import (
+    perturb_categorical_data_one,
+    perturb_continuous_data_one,
+)
 from move.data.preprocessing import one_hot_encode_single
 from move.models.vae import VAE
 from move.training.training_loop import TrainingLoopOutput
-
-# Define perturb_continuous_data_one (not extended)
-
-
-def perturb_continuous_data_one(
-    baseline_dataloader: DataLoader,
-    con_dataset_names: list[str],
-    target_dataset_name: str,
-    target_value: float,
-    index_pert_feat: int,  # Index of the datasetto perturb
-) -> DataLoader:  # change list(DataLoader) to just one DataLoader
-    """Add perturbations to continuous data. For each feature in the target
-    dataset, change its value to target.
-
-    Args:
-        baseline_dataloader: Baseline dataloader
-        con_dataset_names: List of continuous dataset names
-        target_dataset_name: Target continuous dataset to perturb
-        target_value: Target value
-
-    Returns:
-        One dataloader, with the ith dataset perturbed
-    """
-
-    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
-    assert baseline_dataset.con_shapes is not None
-    assert baseline_dataset.con_all is not None
-
-    target_idx = con_dataset_names.index(target_dataset_name)
-    splits = np.cumsum([0] + baseline_dataset.con_shapes)
-    slice_ = slice(*splits[target_idx : target_idx + 2])
-
-    # num_features = baseline_dataset.con_shapes[target_idx]
-    # dataloaders = []
-    i = index_pert_feat
-    # Instead of the loop, we do it only for one
-    # for i in range(num_features):
-    perturbed_con = baseline_dataset.con_all.clone()
-    target_dataset = perturbed_con[:, slice_]
-    target_dataset[:, i] = torch.FloatTensor([target_value])
-    perturbed_dataset = MOVEDataset(
-        baseline_dataset.cat_all,
-        perturbed_con,
-        baseline_dataset.cat_shapes,
-        baseline_dataset.con_shapes,
-    )
-    perturbed_dataloader = DataLoader(
-        perturbed_dataset,
-        shuffle=False,
-        batch_size=baseline_dataloader.batch_size,
-    )
-
-    return perturbed_dataloader
-
-
-def perturb_categorical_data_one(
-    baseline_dataloader: DataLoader,
-    cat_dataset_names: list[str],
-    target_dataset_name: str,
-    target_value: np.ndarray,
-    index_pert_feat: int,
-) -> DataLoader:
-    """Add perturbations to categorical data. For each feature in the target
-    dataset, change its value to target.
-
-    Args:
-        baseline_dataloader: Baseline dataloader
-        cat_dataset_names: List of categorical dataset names
-        target_dataset_name: Target categorical dataset to perturb
-        target_value: Target value
-
-    Returns:
-        List of dataloaders containing all perturbed datasets
-    """
-
-    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
-    assert baseline_dataset.cat_shapes is not None
-    assert baseline_dataset.cat_all is not None
-
-    target_idx = cat_dataset_names.index(target_dataset_name)
-    splits = np.cumsum(
-        [0] + [int.__mul__(*shape) for shape in baseline_dataset.cat_shapes]
-    )
-    slice_ = slice(*splits[target_idx : target_idx + 2])
-
-    target_shape = baseline_dataset.cat_shapes[target_idx]
-    # num_features = target_shape[0]  # CHANGE
-
-    i = index_pert_feat
-    # dataloaders = []
-    # for i in range(num_features):
-    perturbed_cat = baseline_dataset.cat_all.clone()
-    target_dataset = perturbed_cat[:, slice_].view(
-        baseline_dataset.num_samples, *target_shape
-    )
-    target_dataset[:, i, :] = torch.FloatTensor(target_value)
-    perturbed_dataset = MOVEDataset(
-        perturbed_cat,
-        baseline_dataset.con_all,
-        baseline_dataset.cat_shapes,
-        baseline_dataset.con_shapes,
-    )
-    perturbed_dataloader = DataLoader(
-        perturbed_dataset,
-        shuffle=False,
-        batch_size=baseline_dataloader.batch_size,
-    )
-
-    return perturbed_dataloader
 
 
 def find_feature_values(
     feature_name: str,
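Likewise for the retained categorical helper, which expects a one-hot `target_value`. A hedged sketch, assuming a three-category dataset named "genomics" (both assumptions, not from the diff; in MOVE, `one_hot_encode_single` can build such target values):

```python
import numpy as np
from torch.utils.data import DataLoader

from move.data.perturbations import perturb_categorical_data_one


def force_first_category(baseline_dataloader: DataLoader) -> DataLoader:
    """Illustrative wrapper: force feature 0 of a hypothetical 3-category
    'genomics' dataset to its first category via a one-hot target value."""
    one_hot_target = np.array([1.0, 0.0, 0.0])  # assumed 3 categories
    return perturb_categorical_data_one(
        baseline_dataloader,
        cat_dataset_names=["genomics"],  # hypothetical dataset name
        target_dataset_name="genomics",
        target_value=one_hot_target,  # np.ndarray, per the signature above
        index_pert_feat=0,  # one feature only; the per-feature loop was removed
    )
```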
