Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Inform submission about evaluation step #720

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion algorithmic_efficiency/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,8 @@ def update_params(workload: Workload,
optimizer_state: OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: RandomState) -> UpdateReturn:
rng: RandomState,
is_eval_step: bool = False) -> UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,13 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool = False) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del eval_results
del is_eval_step

current_model = current_param_container
current_model.train()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del current_params_types
del hyperparameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def pmapped_update_params(workload: spec.Workload,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
optimizer_state: spec.OptimizerState,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
del hyperparameters

def loss_fn(params):
Expand Down Expand Up @@ -85,7 +86,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del hyperparameters
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
Expand Down
18 changes: 14 additions & 4 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def train_once(
global_step = 0
eval_results = []
preemption_count = 0
is_eval_step = False

# Loggers and checkpoint setup.
logging.info('Initializing checkpoint and logger.')
Expand Down Expand Up @@ -325,7 +326,7 @@ def train_once(
while train_state['is_time_remaining'] and \
not goals_reached and \
not train_state['training_complete']:

step_rng = prng.fold_in(rng, global_step)
data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)

Expand All @@ -338,6 +339,15 @@ def train_once(
hyperparameters,
global_step,
data_select_rng)

# Check if submission is eligible for an untimed eval *before* update_params
if global_step >=1 and \
((train_state['last_step_end_time'] - train_state['last_eval_time']) >=
workload.eval_period_time_sec):
is_eval_step = True
else:
is_eval_step = False

try:
with profiler.profile('Update parameters'):
optimizer_state, model_params, model_state = update_params(
Expand All @@ -351,7 +361,8 @@ def train_once(
optimizer_state=optimizer_state,
eval_results=eval_results,
global_step=global_step,
rng=update_rng)
rng=update_rng,
is_eval_step=is_eval_step)
except spec.TrainingCompleteError:
train_state['training_complete'] = True
global_step += 1
Expand All @@ -369,8 +380,7 @@ def train_once(
train_state['is_time_remaining'] = (
train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
# Check if submission is eligible for an untimed eval.
if ((train_step_end_time - train_state['last_eval_time']) >=
workload.eval_period_time_sec or train_state['training_complete']):
if is_eval_step or train_state['training_complete']:
with profiler.profile('Evaluation'):
del batch
_reset_cuda_mem()
Expand Down
3 changes: 2 additions & 1 deletion submissions/template/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def update_params(workload: spec.Workload,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
rng: spec.RandomState,
is_eval_step: bool) -> spec.UpdateReturn:
"""
Returns:
(new_optimizer_state, update_fn)
Expand Down
Loading