diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py
index 285983957..05aa9a3de 100644
--- a/algorithmic_efficiency/spec.py
+++ b/algorithmic_efficiency/spec.py
@@ -424,7 +424,8 @@ def update_params(workload: Workload,
                   optimizer_state: OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: RandomState) -> UpdateReturn:
+                  rng: RandomState,
+                  is_eval_step: bool = False) -> UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   pass
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
index 099613fcf..60712f20b 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
@@ -262,7 +262,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
index ef0c11c0d..7d978110e 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
@@ -262,7 +262,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
index 01cffc52e..f6c6c3164 100644
--- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
+++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
@@ -234,11 +234,13 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool = False) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
   del eval_results
+  del is_eval_step

   current_model = current_param_container
   current_model.train()
diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
index 530dd3acf..01ce687d1 100644
--- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
+++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
@@ -234,7 +234,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
index be8b2f7e5..6baa79b87 100644
--- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
@@ -274,7 +274,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
index 6c859b8dd..ba2cfbee2 100644
--- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
+++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
@@ -274,7 +274,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
index 57da48167..b31a0fbb1 100644
--- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
+++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
@@ -246,7 +246,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
index ef6e84c94..b07610b75 100644
--- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
+++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
@@ -246,7 +246,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
index 2971efe9a..ad2cd6e2d 100644
--- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
+++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
@@ -120,7 +120,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
index 358c6bffc..fac876037 100644
--- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
+++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
@@ -63,7 +63,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params)."""
   del current_params_types
   del hyperparameters
diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
index 896609d51..13852cf52 100644
--- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
+++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
@@ -51,7 +51,8 @@ def pmapped_update_params(workload: spec.Workload,
                           hyperparameters: spec.Hyperparameters,
                           batch: Dict[str, spec.Tensor],
                           optimizer_state: spec.OptimizerState,
-                          rng: spec.RandomState) -> spec.UpdateReturn:
+                          rng: spec.RandomState,
+                          is_eval_step: bool) -> spec.UpdateReturn:
   del hyperparameters

   def loss_fn(params):
@@ -85,7 +86,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py
index f1601e606..d43d72e87 100644
--- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py
+++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py
@@ -42,7 +42,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params)."""
   del hyperparameters
   del loss_type
diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py
index 2dd85c29b..f8167822c 100644
--- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py
+++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py
@@ -120,7 +120,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
index e6fef17dc..2dc33dd32 100644
--- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
@@ -200,7 +200,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py
index 11212c1a0..28801b9a7 100644
--- a/reference_algorithms/paper_baselines/adamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py
@@ -120,7 +120,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
index 75a4abbef..7c9ec9af3 100644
--- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
@@ -61,7 +61,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py
index 27d635ee9..3ff03acce 100644
--- a/reference_algorithms/paper_baselines/lamb/jax/submission.py
+++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py
@@ -128,7 +128,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
index 7d0d8763e..7139abe8f 100644
--- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
@@ -199,7 +199,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py
index 4139ebcf6..3102de778 100644
--- a/reference_algorithms/paper_baselines/momentum/jax/submission.py
+++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py
@@ -154,7 +154,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
index b7d87924d..cab0618c0 100644
--- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
@@ -77,7 +77,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
index 099613fcf..60712f20b 100644
--- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
@@ -262,7 +262,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
index 01cffc52e..74224ec1a 100644
--- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
@@ -234,7 +234,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
index 35cebba1f..361acaea9 100644
--- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
@@ -154,7 +154,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
index 45feb8645..17bae4c35 100644
--- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
@@ -77,7 +77,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py
index 85b3d7441..96cf04abb 100644
--- a/reference_algorithms/paper_baselines/sam/jax/submission.py
+++ b/reference_algorithms/paper_baselines/sam/jax/submission.py
@@ -207,7 +207,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py
index 2cab75972..ec5595209 100644
--- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py
@@ -141,7 +141,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py
index 9c6b66b7f..d5e174057 100644
--- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py
+++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py
@@ -123,7 +123,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py
index 2a641b520..11c4e198f 100644
--- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py
+++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py
@@ -79,7 +79,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
index f9e40212b..c68c61bb8 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
@@ -22,7 +22,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """Return (updated_optimizer_state, updated_params, updated_model_state)."""
   del current_params_types
   del loss_type
diff --git a/submission_runner.py b/submission_runner.py
index ff290079b..cd1125d1d 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -282,6 +282,7 @@ def train_once(
   global_step = 0
   eval_results = []
   preemption_count = 0
+  is_eval_step = False

   # Loggers and checkpoint setup.
   logging.info('Initializing checkpoint and logger.')
@@ -325,7 +326,7 @@
   while train_state['is_time_remaining'] and \
       not goals_reached and \
       not train_state['training_complete']:
-
+
     step_rng = prng.fold_in(rng, global_step)
     data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
@@ -338,6 +339,15 @@
                              hyperparameters,
                              global_step,
                              data_select_rng)
+
+    # Check if submission is eligible for an untimed eval *before* update_params
+    if global_step >= 1 and \
+        ((train_state['last_step_end_time'] - train_state['last_eval_time']) >=
+         workload.eval_period_time_sec):
+      is_eval_step = True
+    else:
+      is_eval_step = False
+
     try:
       with profiler.profile('Update parameters'):
         optimizer_state, model_params, model_state = update_params(
@@ -351,7 +361,8 @@
             optimizer_state=optimizer_state,
             eval_results=eval_results,
             global_step=global_step,
-            rng=update_rng)
+            rng=update_rng,
+            is_eval_step=is_eval_step)
     except spec.TrainingCompleteError:
       train_state['training_complete'] = True
     global_step += 1
@@ -369,8 +380,7 @@
     train_state['is_time_remaining'] = (
         train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
     # Check if submission is eligible for an untimed eval.
-    if ((train_step_end_time - train_state['last_eval_time']) >=
-        workload.eval_period_time_sec or train_state['training_complete']):
+    if is_eval_step or train_state['training_complete']:
       with profiler.profile('Evaluation'):
         del batch
         _reset_cuda_mem()
diff --git a/submissions/template/submission.py b/submissions/template/submission.py
index 0448a46ed..d84a6cfdb 100644
--- a/submissions/template/submission.py
+++ b/submissions/template/submission.py
@@ -32,7 +32,8 @@ def update_params(workload: spec.Workload,
                   optimizer_state: spec.OptimizerState,
                   eval_results: List[Tuple[int, float]],
                   global_step: int,
-                  rng: spec.RandomState) -> spec.UpdateReturn:
+                  rng: spec.RandomState,
+                  is_eval_step: bool) -> spec.UpdateReturn:
   """
   Returns:
    (new_optimizer_state, update_fn)
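
For orientation (not part of the patch): below is a minimal sketch of how a submission's `update_params` might act on the new `is_eval_step` flag, for example to publish averaged weights just before the runner's untimed eval. The helpers `_optimizer_step` and `_swap_in_averaged_weights` are hypothetical stand-ins for a submission's real update and averaging logic; only the signature, with the `is_eval_step: bool = False` default from `spec.py`, mirrors the diff above.

```python
from typing import Dict, List, Tuple

from algorithmic_efficiency import spec


def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState,
                  is_eval_step: bool = False) -> spec.UpdateReturn:
  """Sketch of a submission reacting to the runner's eval signal."""
  del current_params_types
  del loss_type
  del eval_results

  # Hypothetical helper: one ordinary optimizer step on `batch`.
  optimizer_state, updated_params, model_state = _optimizer_step(
      workload, current_param_container, model_state, hyperparameters, batch,
      optimizer_state, global_step, rng)

  if is_eval_step:
    # The runner has decided, before calling update_params, that an untimed
    # eval will follow this step. This is the natural point to copy, e.g.,
    # EMA/Polyak-averaged weights into the container the eval reads from.
    # Hypothetical helper, not part of the spec.
    updated_params = _swap_in_averaged_weights(optimizer_state, updated_params)

  return optimizer_state, updated_params, model_state
```

Because `spec.py` declares the argument with a `False` default, submissions that ignore the flag keep working unchanged; the baselines above simply accept and discard it (e.g., `del is_eval_step` in `pytorch_nadamw_full_budget.py`).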