Skip to content

Commit

Permalink
Cosmoflow model updates
Browse files Browse the repository at this point in the history
  • Loading branch information
fiedorowicz1 committed May 31, 2024
1 parent c614cd3 commit 67fa8b8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
8 changes: 5 additions & 3 deletions applications/physics/cosmology/cosmoflow/cosmoflow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def construct_cosmoflow_model(parallel_strategy,
learning_rate,
min_distconv_width,
mlperf,
transform_input):
transform_input,
dropout_keep_prob=0.5):

# Construct layer graph
universes = lbann.Input(data_field='samples')
Expand All @@ -23,7 +24,8 @@ def construct_cosmoflow_model(parallel_strategy,
use_bn=use_batchnorm,
bn_statistics_group_size=statistics_group_size,
mlperf=mlperf,
transform_input=transform_input)(universes)
transform_input=transform_input,
dropout_keep_prob=dropout_keep_prob)(universes)
mse = lbann.MeanSquaredError([preds, secrets])
mae = lbann.MeanAbsoluteError([preds, secrets])
obj = lbann.ObjectiveFunction([mse])
Expand Down Expand Up @@ -71,7 +73,7 @@ def construct_cosmoflow_model(parallel_strategy,
# initial_warmup_learning_rate=0,
# warmup_steps=100
# ),
lbann.CallbackProgressBar(newline_interval=1)
lbann.CallbackProgressBar(newline_interval=1, print_mem_usage=True)
]
return lbann.Model(
epochs=num_epochs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def __init__(self,
use_bn=False,
bn_statistics_group_size=None,
mlperf=False,
transform_input=False):
transform_input=False,
dropout_keep_prob=0.5):
"""Initialize CosmFlow.
Args:
Expand All @@ -43,6 +44,8 @@ def __init__(self,
model.
transform_input (bool): Whether or not to apply log1p
transformation to model inputs.
dropout_keep_prob (float): Probability of not zeroing out
activations in dropout layers. Setting to 1 disables dropout.
"""

CosmoFlow.global_count += 1
Expand All @@ -53,6 +56,7 @@ def __init__(self,
self.use_bn = use_bn
self.mlperf = mlperf
self.transform_input = transform_input
self.dropout_keep_prob = dropout_keep_prob

if self.mlperf:
base_channels = 32
Expand Down Expand Up @@ -144,8 +148,11 @@ def create_act(x, i):
self.name, i, self.instance))

def create_dropout(x, i):
if self.dropout_keep_prob == 1:
return x

return lbann.Dropout(
x, keep_prob=0.5,
x, keep_prob=self.dropout_keep_prob,
name='{0}_fc_drop{1}_instance{2}'.format(
self.name, i, self.instance))

Expand Down
16 changes: 12 additions & 4 deletions applications/physics/cosmology/cosmoflow/train_cosmoflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ def create_python_dataset_reader(args):

readers = []
for role in ['train', 'val', 'test']:
role_dir = getattr(args, f'{role}_dir')
role_dir = getattr(args, f'{role}_dir', None)
if not role_dir:
continue
if role == 'val':
role = 'validate'
dataset = CosmoFlowDataset(role_dir, args.input_width, args.num_secrets)
reader = lbann.util.data.construct_python_dataset_reader(dataset, role=role)
readers.append(reader)
Expand Down Expand Up @@ -142,7 +146,7 @@ def create_synthetic_data_reader(input_width: int, num_responses: int) -> Any:
default_dir = '{}/{}'.format(default_lc_dataset, role)
parser.add_argument(
'--{}-dir'.format(role), action='store', type=str,
default=default_dir,
default=default_dir if role == 'train' else None,
help='the directory of the {} dataset'.format(role))
parser.add_argument(
'--synthetic', action='store_true',
Expand All @@ -156,6 +160,9 @@ def create_synthetic_data_reader(input_width: int, num_responses: int) -> Any:
parser.add_argument(
'--transform-input', action='store_true',
help='Apply log1p transformation to model inputs')
parser.add_argument(
'--dropout-keep-prob', action='store', type=float, default=0.5,
help='Probability of keeping activations in dropout layers (default: 0.5). Set to 1 to disable dropout')

# Parallelism arguments
parser.add_argument(
Expand Down Expand Up @@ -227,7 +234,8 @@ def create_synthetic_data_reader(input_width: int, num_responses: int) -> Any:
learning_rate=args.optimizer_learning_rate,
min_distconv_width=args.min_distconv_width,
mlperf=args.mlperf,
transform_input=args.transform_input)
transform_input=args.transform_input,
dropout_keep_prob=args.dropout_keep_prob)

# Add profiling callbacks if needed.
model.callbacks.extend(lbann.contrib.args.create_profile_callbacks(args))
Expand Down Expand Up @@ -274,7 +282,7 @@ def create_synthetic_data_reader(input_width: int, num_responses: int) -> Any:
environment['DISTCONV_JIT_CACHEPATH'] = f'{application_path}/DaCe_kernels/.dacecache'

if args.synthetic or args.no_datastore:
lbann_args = []
lbann_args = ['--num_io_threads=8']
else:
lbann_args = ['--use_data_store']
lbann_args += lbann.contrib.args.get_profile_args(args)
Expand Down

0 comments on commit 67fa8b8

Please sign in to comment.