Skip to content

Commit

Permalink
fix sample_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed Jun 24, 2024
1 parent b8c07ab commit 4b050a4
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 13 deletions.
6 changes: 3 additions & 3 deletions alpha_automl/automl_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from alpha_automl.scorer import make_splitter, score_pipeline
from alpha_automl.utils import sample_dataset, is_equal_splitting
from alpha_automl.pipeline_synthesis.setup_search import search_pipelines as search_pipelines_proc

from alpha_automl.pipeline_search.agent_lab import read_result_to_pipeline

USE_AUTOMATIC_GRAMMAR = False
PRIORITIZE_PRIMITIVES = False
EXCLUDE_PRIMITIVES = []
INCLUDE_PRIMITIVES = []
NEW_PRIMITIVES = {}
SPLITTING_STRATEGY = 'holdout'
SAMPLE_SIZE = 2000
SAMPLE_SIZE = 50000
MAX_RUNNING_PROCESSES = multiprocessing.cpu_count()

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,8 +52,8 @@ def search_pipelines(self, X, y, scoring, splitting_strategy, automl_hyperparams
def _search_pipelines(self, automl_hyperparams):
search_start_time = time.time()
automl_hyperparams = self.check_automl_hyperparams(automl_hyperparams)
metadata = profile_data(self.X)
X, y, is_sample = sample_dataset(self.X, self.y, SAMPLE_SIZE, self.task)
metadata = profile_data(X)
internal_splitting_strategy = make_splitter(SPLITTING_STRATEGY)
self.found_pipelines = 0
need_rescoring = True
Expand Down
9 changes: 3 additions & 6 deletions alpha_automl/data_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@


def profile_data(X):
metadata = {'nonnumeric_columns': {}, 'useless_columns': [], 'missing_values': False, 'numeric_columns': [], 'categorical_columns': []}
metadata = {'nonnumeric_columns': {}, 'useless_columns': [], 'missing_values': X.isnull().values.any(), 'numeric_columns': [], 'categorical_columns': []}
mapping_encoders = {CATEGORICAL_COLUMN: 'CATEGORICAL_ENCODER', DATETIME_COLUMN: 'DATETIME_ENCODER',
TEXT_COLUMN: 'TEXT_ENCODER', IMAGE_COLUMN: 'IMAGE_ENCODER'}

profiled_data = datamart_profiler.process_dataset(X, coverage=False, indexes=False)
profiled_data = datamart_profiler.process_dataset(X.sample(n=100, replace=True, random_state=1), coverage=False, indexes=False)

for index_column, profiled_column in enumerate(profiled_data['columns']):
column_name = profiled_column['name']
Expand All @@ -40,9 +40,6 @@ def profile_data(X):
column_type = mapping_encoders[TEXT_COLUMN]
add_nonnumeric_column(column_type, metadata, index_column, column_name)

if 'missing_values_ratio' in profiled_column:
metadata['missing_values'] = True

metadata['numeric_columns'] = list(X.select_dtypes(include=['int64', 'float64']).columns)
metadata['categorical_columns'] = list(X.select_dtypes(include=['object', 'category']).columns)

Expand Down
4 changes: 3 additions & 1 deletion alpha_automl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def create_object(import_path, class_params=None):


def sample_dataset(X, y, sample_size, task):
original_size = len(X)
original_rows = len(X)
original_cols = len(X.columns)
original_size = original_rows * original_cols
shuffle = True
if task == 'TIME_SERIES_FORECAST':
shuffle = False
Expand Down
6 changes: 3 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def test_sample_dataset():
dataset = pd.read_csv(dataset_path)
X = dataset.drop(columns=["rating"])
y = dataset[["rating"]]
sample_size = 10
sample_size = 110 # 11 * 10

actual_X, actual_y, actual_is_sampled = sample_dataset(
X, y, sample_size, "CLASSIFICATION"
)
expected_X_len = sample_size
expected_y_len = sample_size
expected_X_len = sample_size // len(X.columns)
expected_y_len = sample_size // len(X.columns)
expected_is_sampled = True

assert actual_is_sampled == expected_is_sampled
Expand Down

0 comments on commit 4b050a4

Please sign in to comment.