From e6c2106c2460d0149235dd4eccfd4017b0952734 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Sep 2024 15:30:15 +0200 Subject: [PATCH 1/8] added prepare_for_eval, eval only if is_time_remaining --- algorithmic_efficiency/spec.py | 14 +- submission_runner.py | 203 ++++++++++++++++------------- submissions/template/submission.py | 20 +++ 3 files changed, 149 insertions(+), 88 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 285983957..792093a2e 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -406,7 +406,19 @@ def init_optimizer_state(workload: Workload, RandomState ], UpdateReturn] - +PrepareForEvalFn = Callable[[ + Workload, + ParameterContainer, + ParameterTypeTree, + ModelAuxiliaryState, + Hyperparameters, + LossType, + OptimizerState, + List[Tuple[int, float]], + int, + RandomState +], + UpdateReturn] # Each call to this function is considered a "step". # Can raise a TrainingCompleteError if it believes it has achieved the goal and diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..5df1f05ff 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -200,6 +200,7 @@ def train_once( init_optimizer_state: spec.InitOptimizerFn, update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, + prepare_for_eval: spec.PrepareForEvalFn, hyperparameters: Optional[spec.Hyperparameters], rng_seed: int, rng: spec.RandomState, @@ -335,7 +336,7 @@ def train_once( not train_state['training_complete']: step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) + data_select_rng, update_rng, prep_eval_rng, eval_rng = prng.split(step_rng, 4) with profiler.profile('Data selection'): batch = data_selection(workload, @@ -370,101 +371,128 @@ def train_once( train_state['accumulated_submission_time'] += ( train_step_end_time - train_state['last_step_end_time']) - # Use 3x the runtime budget for the 
self-tuning ruleset. - max_allowed_runtime_sec = ( - workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' - else 3 * workload.max_allowed_runtime_sec) - train_state['is_time_remaining'] = ( - train_state['accumulated_submission_time'] < max_allowed_runtime_sec) + # Check if submission is eligible for an untimed eval. if ((train_step_end_time - train_state['last_eval_time']) >= workload.eval_period_time_sec or train_state['training_complete']): - with profiler.profile('Evaluation'): + + # Prepare for evaluation (timed). + with profiler.profile('Prepare for eval'): del batch - _reset_cuda_mem() - - try: - eval_start_time = get_time() - latest_eval_result = workload.eval_model(global_eval_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir, - global_step) - # Check if targets reached. - # Note that this is one of the stopping conditions for the length of - # a training run. To score the run we only consider the time - # to validation target retrospectively. - train_state['validation_goal_reached'] = ( - workload.has_reached_validation_target(latest_eval_result) or - train_state['validation_goal_reached']) - train_state['test_goal_reached'] = ( - workload.has_reached_test_target(latest_eval_result) or - train_state['test_goal_reached']) - goals_reached = ( - train_state['validation_goal_reached'] and - train_state['test_goal_reached']) - # Save last eval time. - eval_end_time = get_time() - train_state['last_eval_time'] = eval_end_time - - # Accumulate eval time. - train_state[ - 'accumulated_eval_time'] += eval_end_time - eval_start_time - - # Add times to eval results for logging. 
- latest_eval_result['score'] = ( - train_state['accumulated_submission_time']) - latest_eval_result[ - 'total_duration'] = eval_end_time - global_start_time - latest_eval_result['accumulated_submission_time'] = train_state[ - 'accumulated_submission_time'] - latest_eval_result['accumulated_eval_time'] = train_state[ - 'accumulated_eval_time'] - latest_eval_result['accumulated_logging_time'] = train_state[ - 'accumulated_logging_time'] - time_since_start = latest_eval_result['total_duration'] - logging.info(f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}') - eval_results.append((global_step, latest_eval_result)) - - logging_start_time = get_time() - - if log_dir is not None and RANK == 0: - metrics_logger.append_scalar_metrics( - latest_eval_result, - global_step=global_step, - preemption_count=preemption_count, - is_eval=True, - ) - if save_checkpoints: - checkpoint_utils.save_checkpoint( - framework=FLAGS.framework, - optimizer_state=optimizer_state, - model_params=model_params, - model_state=model_state, - train_state=train_state, - eval_results=eval_results, - global_step=global_step, - preemption_count=preemption_count, - checkpoint_dir=log_dir, - save_intermediate_checkpoints=FLAGS - .save_intermediate_checkpoints) + prepare_for_eval_start_time = get_time() + optimizer_state, model_params, model_state = prepare_for_eval( + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=prep_eval_rng) + prepare_for_eval_end_time = get_time() + + # Update sumbission time. + train_state['accumulated_submission_time'] += ( + prepare_for_eval_end_time - prepare_for_eval_start_time) + + # Check if time is remaining, + # use 3x the runtime budget for the self-tuning ruleset. 
+ max_allowed_runtime_sec = ( + workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' + else 3 * workload.max_allowed_runtime_sec) + train_state['is_time_remaining'] = ( + train_state['accumulated_submission_time'] < max_allowed_runtime_sec) - logging_end_time = get_time() - train_state['accumulated_logging_time'] += ( - logging_end_time - logging_start_time) + # Eval if time is remaining (untimed). + if train_state['is_time_remaining']: + with profiler.profile('Evaluation'): _reset_cuda_mem() - except RuntimeError as e: - logging.exception(f'Eval step {global_step} error.\n') - if 'out of memory' in str(e): - logging.warning('Error: GPU out of memory during eval during step ' - f'{global_step}, error : {str(e)}.') + try: + eval_start_time = get_time() + latest_eval_result = workload.eval_model(global_eval_batch_size, + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir, + global_step) + # Check if targets reached. + # Note that this is one of the stopping conditions for the length of + # a training run. To score the run we only consider the time + # to validation target retrospectively. + train_state['validation_goal_reached'] = ( + workload.has_reached_validation_target(latest_eval_result) or + train_state['validation_goal_reached']) + train_state['test_goal_reached'] = ( + workload.has_reached_test_target(latest_eval_result) or + train_state['test_goal_reached']) + goals_reached = ( + train_state['validation_goal_reached'] and + train_state['test_goal_reached']) + # Save last eval time. + eval_end_time = get_time() + train_state['last_eval_time'] = eval_end_time + + # Accumulate eval time. + train_state[ + 'accumulated_eval_time'] += eval_end_time - eval_start_time + + # Add times to eval results for logging. 
+ latest_eval_result['score'] = ( + train_state['accumulated_submission_time']) + latest_eval_result[ + 'total_duration'] = eval_end_time - global_start_time + latest_eval_result['accumulated_submission_time'] = train_state[ + 'accumulated_submission_time'] + latest_eval_result['accumulated_eval_time'] = train_state[ + 'accumulated_eval_time'] + latest_eval_result['accumulated_logging_time'] = train_state[ + 'accumulated_logging_time'] + time_since_start = latest_eval_result['total_duration'] + logging.info(f'Time since start: {time_since_start:.2f}s, ' + f'\tStep: {global_step}, \t{latest_eval_result}') + eval_results.append((global_step, latest_eval_result)) + + logging_start_time = get_time() + + if log_dir is not None and RANK == 0: + metrics_logger.append_scalar_metrics( + latest_eval_result, + global_step=global_step, + preemption_count=preemption_count, + is_eval=True, + ) + if save_checkpoints: + checkpoint_utils.save_checkpoint( + framework=FLAGS.framework, + optimizer_state=optimizer_state, + model_params=model_params, + model_state=model_state, + train_state=train_state, + eval_results=eval_results, + global_step=global_step, + preemption_count=preemption_count, + checkpoint_dir=log_dir, + save_intermediate_checkpoints=FLAGS + .save_intermediate_checkpoints) + + logging_end_time = get_time() + train_state['accumulated_logging_time'] += ( + logging_end_time - logging_start_time) + _reset_cuda_mem() + except RuntimeError as e: + logging.exception(f'Eval step {global_step} error.\n') + if 'out of memory' in str(e): + logging.warning('Error: GPU out of memory during eval during step ' + f'{global_step}, error : {str(e)}.') + _reset_cuda_mem() + train_state['last_step_end_time'] = get_time() metrics = {'eval_results': eval_results, 'global_step': global_step} @@ -518,6 +546,7 @@ def score_submission_on_workload(workload: spec.Workload, init_optimizer_state = submission_module.init_optimizer_state update_params = submission_module.update_params data_selection 
= submission_module.data_selection + prepare_for_eval = submission_module.prepare_for_eval try: global_batch_size = submission_module.get_batch_size(workload_name) except ValueError: @@ -589,7 +618,7 @@ def score_submission_on_workload(workload: spec.Workload, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, - update_params, data_selection, + update_params, data_selection, prepare_for_eval, hyperparameters, rng_seed, rng, diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 5ef195db5..848d8af44 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -42,6 +42,26 @@ def update_params(workload: spec.Workload, pass +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + # batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """ + Returns: + new_optimizer_state + new_params + new_model_state + """ + pass + + def get_batch_size(workload_name): """ Gets batch size for workload. 
From 8bad99d663f34ce4b6b6c4a2a40b828e19fc3a5b Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sat, 14 Sep 2024 18:11:43 +0200 Subject: [PATCH 2/8] added prepare_for_eval to all submissions --- .../external_tuning/jax_nadamw_full_budget.py | 21 ++++++++++++++++ .../jax_nadamw_target_setting.py | 21 ++++++++++++++++ .../pytorch_nadamw_full_budget.py | 21 ++++++++++++++++ .../pytorch_nadamw_target_setting.py | 21 ++++++++++++++++ .../self_tuning/jax_nadamw_full_budget.py | 21 ++++++++++++++++ .../self_tuning/jax_nadamw_target_setting.py | 21 ++++++++++++++++ .../self_tuning/pytorch_nadamw_full_budget.py | 21 ++++++++++++++++ .../pytorch_nadamw_target_setting.py | 21 ++++++++++++++++ .../cifar/cifar_jax/submission.py | 25 +++++++++++++++++-- .../cifar/cifar_pytorch/submission.py | 21 ++++++++++++++++ .../mnist/mnist_jax/submission.py | 21 ++++++++++++++++ .../mnist/mnist_pytorch/submission.py | 21 ++++++++++++++++ .../adafactor/jax/submission.py | 21 ++++++++++++++++ .../adafactor/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/adamw/jax/submission.py | 21 ++++++++++++++++ .../adamw/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/lamb/jax/submission.py | 21 ++++++++++++++++ .../lamb/pytorch/submission.py | 21 ++++++++++++++++ .../momentum/jax/submission.py | 21 ++++++++++++++++ .../momentum/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/nadamw/jax/submission.py | 21 ++++++++++++++++ .../nadamw/pytorch/submission.py | 21 ++++++++++++++++ .../nesterov/jax/submission.py | 21 ++++++++++++++++ .../nesterov/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/sam/jax/submission.py | 21 ++++++++++++++++ .../paper_baselines/sam/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/shampoo/jax/submission.py | 21 ++++++++++++++++ .../jax_submission_base.py | 21 ++++++++++++++++ .../pytorch_submission_base.py | 21 ++++++++++++++++ 29 files changed, 611 insertions(+), 2 deletions(-) diff --git 
a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 98193f01f..5f203c5c6 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -299,6 +299,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 66fdc4ebb..32f4e830e 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -299,6 +299,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index ebc49d428..ba56cd99f 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -301,6 +301,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 524bc20af..e2c44d9c1 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -301,6 +301,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4f53afb56..502b7e5b4 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -314,6 +314,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60a1f784d..8bc2eed95 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -314,6 +314,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index f8e87ec2a..bbf548ccb 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -316,6 +316,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 1de26417f..992f769f3 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -316,6 +316,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 2971efe9a..b2256fc5a 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -108,8 +108,6 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. 
def update_params(workload: spec.Workload, current_param_container: spec.ParameterContainer, current_params_types: spec.ParameterTypeTree, @@ -134,6 +132,29 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +# Not allowed to update the model parameters, hyperparameters, global step, or +# optimzier state. 
def data_selection(workload: spec.Workload, input_queue: Iterator[Dict[str, spec.Tensor]], optimizer_state: spec.OptimizerState, diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 358c6bffc..b55c31afc 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -96,6 +96,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. 
def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 896609d51..f09886215 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -106,6 +106,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), updated_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. 
def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index f1601e606..8b5151c77 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -72,6 +72,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. 
def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 2dd85c29b..ed2ee371f 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -157,6 +157,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index e6fef17dc..5f6540020 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -265,6 +265,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 80a963600..5d2107ba6 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -157,6 +157,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 32353e5b4..2b42bb5a4 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -125,6 +125,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 27d635ee9..e08d5b433 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -165,6 +165,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7d0d8763e..da5865087 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -258,6 +258,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index cccb3c1b5..1ab362dd6 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -191,6 +191,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index ec5c0b31c..999321bd5 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -144,6 +144,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 98193f01f..5f203c5c6 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -299,6 +299,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index ebc49d428..ba56cd99f 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -301,6 +301,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index f3b0aeed4..20109a9e3 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -191,6 +191,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index fe9154934..b4b8b77af 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -144,6 +144,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 85b3d7441..9f12c4f3f 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -244,6 +244,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 2cab75972..cf5e49f4f 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -216,6 +216,27 @@ def _loss_fn(params, update_batch_norm=True): return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 9c6b66b7f..b596f0bdc 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -160,6 +160,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 2a641b520..31e8a8850 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -109,3 +109,24 @@ def update_params(workload: spec.Workload, 'grad_norm': grad_norm[0], }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index f9e40212b..549d2dc58 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -89,3 +89,24 @@ def update_params(workload: spec.Workload, grad_norm.item()) return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: 
spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) From 1c7d51c0eb2cf64295f030b6ef0566bcd24b01cf Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Sep 2024 11:08:48 +0200 Subject: [PATCH 3/8] fix formatting --- submission_runner.py | 25 ++++++++++++++----------- submissions/template/submission.py | 1 - 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 5df1f05ff..a711be9ac 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -336,7 +336,8 @@ def train_once( not train_state['training_complete']: step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, prep_eval_rng, eval_rng = prng.split(step_rng, 4) + data_select_rng, update_rng, prep_eval_rng, eval_rng = \ + prng.split(step_rng, 4) with profiler.profile('Data selection'): batch = data_selection(workload, @@ -414,12 +415,12 @@ def train_once( try: eval_start_time = get_time() latest_eval_result = workload.eval_model(global_eval_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir, - global_step) + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir, + global_step) # Check if targets reached. # Note that this is one of the stopping conditions for the length of # a training run. 
To score the run we only consider the time @@ -454,7 +455,7 @@ def train_once( 'accumulated_logging_time'] time_since_start = latest_eval_result['total_duration'] logging.info(f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}') + f'\tStep: {global_step}, \t{latest_eval_result}') eval_results.append((global_step, latest_eval_result)) logging_start_time = get_time() @@ -489,8 +490,9 @@ def train_once( except RuntimeError as e: logging.exception(f'Eval step {global_step} error.\n') if 'out of memory' in str(e): - logging.warning('Error: GPU out of memory during eval during step ' - f'{global_step}, error : {str(e)}.') + logging.warning( + 'Error: GPU out of memory during eval during step ' + f'{global_step}, error : {str(e)}.') _reset_cuda_mem() train_state['last_step_end_time'] = get_time() @@ -618,7 +620,8 @@ def score_submission_on_workload(workload: spec.Workload, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, - update_params, data_selection, prepare_for_eval, + update_params, data_selection, + prepare_for_eval, hyperparameters, rng_seed, rng, diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 848d8af44..445e1f7cd 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -47,7 +47,6 @@ def prepare_for_eval(workload: spec.Workload, current_params_types: spec.ParameterTypeTree, model_state: spec.ModelAuxiliaryState, hyperparameters: spec.Hyperparameters, - # batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, eval_results: List[Tuple[int, float]], From 21a580b56c1f19cff11b13b62d4fceb1dc003f29 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Sep 2024 11:18:03 +0200 Subject: [PATCH 4/8] fix formatting --- algorithmic_efficiency/spec.py | 3 ++- submission_runner.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/spec.py 
b/algorithmic_efficiency/spec.py index 792093a2e..25bd7b6d0 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -418,7 +418,8 @@ def init_optimizer_state(workload: Workload, int, RandomState ], - UpdateReturn] + UpdateReturn] + # Each call to this function is considered a "step". # Can raise a TrainingCompleteError if it believes it has achieved the goal and diff --git a/submission_runner.py b/submission_runner.py index a711be9ac..632cb450b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -620,7 +620,7 @@ def score_submission_on_workload(workload: spec.Workload, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, - update_params, data_selection, + update_params, data_selection, prepare_for_eval, hyperparameters, rng_seed, From 420b583f8bd60ca13b6b7cf9a7d0b8211d5c904b Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Sep 2024 12:48:13 +0200 Subject: [PATCH 5/8] updated documentation --- DOCUMENTATION.md | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 607f47ead..586e03d8c 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -80,7 +80,7 @@ In principle, submissions are allowed to use the available hardware systems in a Submissions provide a [per-workload batch size](#batch-size-getter) to use. Specification of the batch size for each workload is necessary to avoid running out of memory for different workloads. Therefore, submitters can determine this batch size in advance and specify it as part of the submission. Submitters may also provide per-workload batch sizes for all [randomized workloads](#randomized-workloads). 
If no such batch size is provided for a randomized workload, by default, submissions will then use the batch size of the most similar [fixed workload](#fixed-workloads) (for example, if there is an ImageNet fixed workload and also a randomized workload with a similarly sized model on similarly sized images, the ImageNet batch size will be used for held-out workloads generated from this randomized workload). Note that submitters are *not* allowed to modify the *evaluation batch size*, which is set by the benchmarking codebase. However, you can file an issue if you believe that the evaluation batch size of a particular workload is set inappropriately. The working group will review this request and consider adjusting the evaluation batch size in the benchmarking codebase, thus affecting all submitters equally. -The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code. +The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, *prepare for evaluation function*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code. ##### Fixed functions @@ -218,9 +218,35 @@ def update_params( - Cannot modify the given hyperparameters in a workload-conditional way (please see the [Valid submission](#valid-submissions) section). This rule is intended to prohibit circumventing the tuning rules by looking up a pre-tuned optimal set of hyperparameters for each workload. 
It is not intended to prohibit line searches and other similar techniques. - The fixed `init_model_fn` can optionally be called during training, for example, to reinitialize the model after a failed training effort. - Cannot replace the model parameters with pre-trained ones. -- This API supports Polyak averaging and similar methods that implement moving averages of model parameters. - Batch norm should work here because the `model_fn` will return updated batch norm moving averages when it is told to with `update_batch_norm`. + +###### Prepare for evaluation function + +```python +def prepare_for_eval( + workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState +) -> (updated_optimizer_state, updated_variables, updated_model_state) +``` + +- Arguments are the same as `update_params`, with the only exception of `batch`. +- This function is called when a submission is deemed eligible for an evaluation (see [Evaluation during training](#evaluation-during-training) section). + - The call to `prepare_for_eval` is timed and its runtime accumulates to the overall submission time. + - The returned model parameters are evaluated on the validation and test sets, provided that the accumulated submission time does not exceed the maximum runtime after this function call. +- This API supports Polyak averaging and similar methods that implement moving averages of model parameters. +- Allowed to update model state and model parameters. +- Allowed to update state for the optimizer. +- Cannot replace the model parameters with pre-trained ones. 
+ ###### Data selection ```python @@ -250,7 +276,8 @@ def data_selection( In general, with noisy, non-deterministic training, evaluation frequency can affect training time measurements as more "bites of the apple" potentially allows the training code to exploit instability. We also want to discourage submissions from complicated and unrealistic logic that attempts to guess when training is close to complete and increases the evaluation rate, while not producing a well-sampled training curve at the start of training. Simply allowing submissions complete freedom over evaluation frequency encourages competitors to work to minimize the number of evaluations, which distracts from the primary goal of finding better training algorithms. -Submissions are eligible for an untimed eval every `eval_period` seconds, run as soon as the current call of `update_params` completes. Any additional evaluations performed by the submission code count against the runtime for scoring. The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval and, if so, pausing the clock and running an eval. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs. +Submissions are eligible for an untimed eval every `eval_period` seconds. 
Before proceeding to evaluation, the submission can prepare the model through a call to `prepare_for_eval`, effectively modifying the model parameters and state as well as the optimizer state. Any additional evaluations performed by the submission code count against the runtime for scoring. +The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval; if so, the submission is given the possibility to prepare for evaluation (through a timed call to `prepare_for_eval`). If the accumulated runtime does not exceed the maximum allowed runtime after the preparation step, the clock is paused, and the submission is evaluated. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs. 
#### Valid submissions From d9c4ee9d3a85f55e069db21b39feaf216ee9d42d Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 18 Oct 2024 17:19:41 +0200 Subject: [PATCH 6/8] add prepare_for_eval to spec.py --- algorithmic_efficiency/spec.py | 43 ++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 25bd7b6d0..b8be5fcaa 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -406,19 +406,6 @@ def init_optimizer_state(workload: Workload, RandomState ], UpdateReturn] -PrepareForEvalFn = Callable[[ - Workload, - ParameterContainer, - ParameterTypeTree, - ModelAuxiliaryState, - Hyperparameters, - LossType, - OptimizerState, - List[Tuple[int, float]], - int, - RandomState -], - UpdateReturn] # Each call to this function is considered a "step". @@ -442,6 +429,36 @@ def update_params(workload: Workload, pass +PrepareForEvalFn = Callable[[ + Workload, + ParameterContainer, + ParameterTypeTree, + ModelAuxiliaryState, + Hyperparameters, + LossType, + OptimizerState, + List[Tuple[int, float]], + int, + RandomState +], + UpdateReturn] + + +# Prepare model and optimizer for evaluation. 
+def prepare_for_eval(workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState) -> UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + pass + + DataSelectionFn = Callable[[ Workload, Iterator[Dict[str, Any]], From 9caedc5570550708aba7d2695e15b2480ca7cf0f Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 21 Oct 2024 11:48:35 +0200 Subject: [PATCH 7/8] make prepare_for_eval backward compatible --- submission_runner.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 632cb450b..3ef30ffba 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -378,25 +378,27 @@ def train_once( workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). - with profiler.profile('Prepare for eval'): - del batch - prepare_for_eval_start_time = get_time() - optimizer_state, model_params, model_state = prepare_for_eval( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types, - model_state=model_state, - hyperparameters=hyperparameters, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - eval_results=eval_results, - global_step=global_step, - rng=prep_eval_rng) - prepare_for_eval_end_time = get_time() - - # Update sumbission time. 
- train_state['accumulated_submission_time'] += ( - prepare_for_eval_end_time - prepare_for_eval_start_time) + if prepare_for_eval is not None: + + with profiler.profile('Prepare for eval'): + del batch + prepare_for_eval_start_time = get_time() + optimizer_state, model_params, model_state = prepare_for_eval( + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=prep_eval_rng) + prepare_for_eval_end_time = get_time() + + # Update submission time. + train_state['accumulated_submission_time'] += ( + prepare_for_eval_end_time - prepare_for_eval_start_time) # Check if time is remaining, # use 3x the runtime budget for the self-tuning ruleset. @@ -548,7 +550,7 @@ def score_submission_on_workload(workload: spec.Workload, init_optimizer_state = submission_module.init_optimizer_state update_params = submission_module.update_params data_selection = submission_module.data_selection - prepare_for_eval = submission_module.prepare_for_eval + prepare_for_eval = getattr(submission_module, 'prepare_for_eval', None) try: global_batch_size = submission_module.get_batch_size(workload_name) except ValueError: From 4d74d2ccee73ae6096a9fceff6a7b60c80f8f5a7 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 21 Oct 2024 12:00:29 +0200 Subject: [PATCH 8/8] optional prepare_for_eval arg --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 3ef30ffba..c396cb027 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -200,7 +200,7 @@ def train_once( init_optimizer_state: spec.InitOptimizerFn, update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, - prepare_for_eval: spec.PrepareForEvalFn, + prepare_for_eval: 
Optional[spec.PrepareForEvalFn], hyperparameters: Optional[spec.Hyperparameters], rng_seed: int, rng: spec.RandomState,