From fb617a4827d3aeb5ba642ce110283a8d666591d2 Mon Sep 17 00:00:00 2001 From: EdenWuyifan Date: Tue, 7 May 2024 11:44:39 -0400 Subject: [PATCH] add ensembler supports --- alpha_automl/hyperparameter_tuning/smac.py | 27 +++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/alpha_automl/hyperparameter_tuning/smac.py b/alpha_automl/hyperparameter_tuning/smac.py index b558ce3..5e62bb5 100644 --- a/alpha_automl/hyperparameter_tuning/smac.py +++ b/alpha_automl/hyperparameter_tuning/smac.py @@ -48,8 +48,12 @@ def gen_pipeline(config, pipeline): transformers.append((trans_name, trans_obj, trans_index)) step_obj.__dict__['transformers'] = transformers new_pipeline.steps.append([step_name, create_object(step_name, step_obj.__dict__)]) + elif step_type == 'CLASSIFICATION_SINGLE_ENSEMBLER' or step_type == 'REGRESSION_SINGLE_ENSEMBLER': + estimator = step_obj.estimator + primitive_object = create_object(step_name, {'estimator': estimator}) + new_pipeline.steps.append([step_name, primitive_object]) elif step_type == 'CLASSIFICATION_MULTI_ENSEMBLER' or step_type == 'REGRESSION_MULTI_ENSEMBLER': - estimators = extract_estimators(pipeline, PRIMITIVE_TYPES) + estimators = extract_estimators_smac(step_obj, PRIMITIVE_TYPES) primitive_object = create_object(step_name, {'estimators': estimators}) new_pipeline.steps.append([step_name, primitive_object]) else: @@ -58,6 +62,17 @@ def gen_pipeline(config, pipeline): return new_pipeline +def extract_estimators_smac(step_obj, config): + new_estimators = [] + estimators = step_obj.estimators + while estimators: + estimator_name, estimator_obj = estimators.pop() + estimator_name_lookup, estimator_name_counter = estimator_name.split('-') + new_estimators.append((estimator_name, create_object(estimator_name_lookup, get_primitive_params(config, estimator_name_lookup)))) + + return new_estimators + + def get_primitive_params(config, step_name): params = list(SMAC_DICT[step_name].keys()) class_params = {} @@ -80,6 +95,16 @@ def gen_configspace(pipeline): trans_prim_name = trans_name.split('-')[0] params = SMAC_DICT[trans_prim_name] configspace.add_hyperparameters(cast_primitive(params)) + # elif step_type == 'CLASSIFICATION_SINGLE_ENSEMBLER' or step_type == 'REGRESSION_SINGLE_ENSEMBLER': + # estimator_obj = prim_obj.estimator + # for smac_name, smac_params in SMAC_DICT.items(): + # if estimator_obj.__class__.__name__ in smac_name: + # configspace.add_hyperparameters(cast_primitive(smac_params)) + elif step_type == 'CLASSIFICATION_MULTI_ENSEMBLER' or step_type == 'REGRESSION_MULTI_ENSEMBLER': + for estimator_name, _ in prim_obj.estimators: + estimator_name_lookup, _ = estimator_name.split('-') + params = SMAC_DICT[estimator_name_lookup] + configspace.add_hyperparameters(cast_primitive(params)) except Exception as e: logger.critical(f'[SMAC] {str(e)}') return configspace