Commit

Merge pull request #24 from clessig/iluise/head
Iluise/head
iluise authored Aug 12, 2024
2 parents 224bf91 + ba60b71 commit db4d77f
Showing 4 changed files with 67 additions and 65 deletions.
10 changes: 6 additions & 4 deletions atmorep/core/atmorep_model.py
@@ -186,16 +186,18 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non
self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train,
cf.batch_size,
pre_batch, cf.n_size, cf.num_samples_per_epoch,
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True )
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True,
compute_weights = (cf.losses.count('weighted_mse') > 0) )
self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params,
sampler = None)

self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_val,
cf.batch_size_validation,
pre_batch, cf.n_size, cf.num_samples_validate,
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True )
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True,
compute_weights = (cf.losses.count('weighted_mse') > 0) )
self.data_loader_test = torch.utils.data.DataLoader( self.dataset_test, **loader_params,
sampler = None)

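For context on the new compute_weights argument threaded through both samplers above: the flag is computed once from the configured losses, so the sampler only performs the extra per-token weight computation when a weighted_mse loss will actually consume it. A minimal sketch of the pattern, with a hypothetical SamplerSketch standing in for MultifieldDataSampler:

import torch

class SamplerSketch( torch.utils.data.IterableDataset) :
  # hypothetical stand-in for MultifieldDataSampler; only the flag handling is shown
  def __init__( self, data, compute_weights = False) :
    self.data, self.compute_weights = data, compute_weights

  def __iter__( self) :
    for sample in self.data :
      # skip the (potentially expensive) weight computation unless requested
      weights = torch.ones_like( sample) if self.compute_weights else None
      yield sample, weights

# enable weights only when 'weighted_mse' is among the configured losses
losses = ['mse_ensemble', 'stats']
ds = SamplerSketch( [torch.zeros(3)], compute_weights = losses.count('weighted_mse') > 0)
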
15 changes: 7 additions & 8 deletions atmorep/core/evaluator.py
@@ -55,11 +55,6 @@ def parse_args( cf, args) :
@staticmethod
def run( cf, model_id, model_epoch, devices) :

if not hasattr(cf, 'batch_size'):
cf.batch_size = cf.batch_size_max
if not hasattr(cf, 'batch_size_validation'):
cf.batch_size_validation = cf.batch_size_max

cf.with_mixed_precision = True

# set/over-write options as desired
@@ -82,7 +77,7 @@ def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) :
else :
num_accs_per_task = int( 4 / int( os.environ.get('SLURM_TASKS_PER_NODE', '1')[0] ))
devices = init_torch( num_accs_per_task)
#devices = ['cuda']
#devices = ['cuda:1']

par_rank, par_size = setup_ddp( with_ddp)
cf = Config().load_json( model_id)
@@ -112,6 +107,12 @@ def evaluate( mode, model_id, file_path, args = {}, model_epoch=-2) :
cf.with_mixed_precision = False
if not hasattr(cf, 'with_pytest'):
cf.with_pytest = False
if not hasattr(cf, 'batch_size'):
cf.batch_size = cf.batch_size_max
if not hasattr(cf, 'batch_size_validation'):
cf.batch_size_validation = cf.batch_size_max
if not hasattr(cf, 'years_val'):
cf.years_val = cf.years_test

func = getattr( Evaluator, mode)
func( cf, model_id, model_epoch, devices, args)
@@ -159,8 +160,6 @@ def global_forecast( cf, model_id, model_epoch, devices, args = {}) :
cf.batch_size = 196 #14
if not hasattr(cf, 'batch_size_validation'):
cf.batch_size_validation = 1 #64
if not hasattr(cf, 'batch_size_delta'):
cf.batch_size_delta = 8
if not hasattr(cf, 'num_samples_validate'):
cf.num_samples_validate = 196
#if not hasattr(cf,'with_mixed_precision'):
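
The hasattr guards that this commit moves into evaluate() (together with the new years_val fallback) follow the usual backward-compatibility pattern for configs deserialized from older runs: any attribute a newer code path expects is defaulted if the stored JSON predates it. A minimal sketch, assuming a bare config object:

class Config :
  # sketch of a config loaded from an older checkpoint that lacks newer attributes
  batch_size_max = 32
  years_test = [2021]

cf = Config()

# default any attribute missing from older configs, as evaluate() now does centrally
for attr, default in [ ('batch_size', cf.batch_size_max),
                       ('batch_size_validation', cf.batch_size_max),
                       ('years_val', cf.years_test) ] :
  if not hasattr( cf, attr) :
    setattr( cf, attr, default)
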
60 changes: 30 additions & 30 deletions atmorep/core/train.py
@@ -15,21 +15,17 @@
####################################################################################################

import torch
import numpy as np
import os
import sys
import pdb
import traceback

import pdb
import wandb

import atmorep.config.config as config
from atmorep.core.trainer import Trainer_BERT
from atmorep.utils.utils import Config
from atmorep.utils.utils import setup_ddp
from atmorep.utils.utils import setup_wandb
from atmorep.utils.utils import init_torch
import atmorep.utils.utils as utils


####################################################################################################
@@ -110,40 +106,43 @@ def train() :
# [ total masking rate, rate masking, rate noising, rate for multi-res distortion]
# ]

cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ],
[ 96, 105, 114, 123, 137 ],
[12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ]
# cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ],
# [ 96, 105, 114, 123, 137 ],
# [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ]
# cf.fields_prediction = [ [cf.fields[0][0], 1.] ]

cf.fields = [ [ 'velocity_u', [ 1, 1024, [ ], 0 ],
[ 96, 105, 114, 123, 137 ],
[12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ] ]

cf.fields_prediction = [ [cf.fields[0][0], 1.] ]

# cf.fields = [ [ 'velocity_u', [ 1, 2048, [ ], 0],
# [ 96, 105, 114, 123, 137 ],
# [12, 6, 12], [3, 9, 9], [0.5, 0.9, 0.1, 0.05] ] ]

# cf.fields = [ [ 'velocity_v', [ 1, 2048, [ ], 0 ],

# cf.fields = [ [ 'velocity_v', [ 1, 1024, [ ], 0 ],
# [ 96, 105, 114, 123, 137 ],
# [ 12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ]
# [12, 3, 6], [3, 18, 18], [0.25, 0.9, 0.1, 0.05] ] ]

# cf.fields = [ [ 'velocity_z', [ 1, 1024, [ ], 0 ],
# cf.fields = [ [ 'velocity_z', [ 1, 512, [ ], 0 ],
# [ 96, 105, 114, 123, 137 ],
# [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ]

# cf.fields = [ [ 'specific_humidity', [ 1, 2048, [ ], 0 ],
# cf.fields = [ [ 'specific_humidity', [ 1, 1024, [ ], 0 ],
# [ 96, 105, 114, 123, 137 ],
# [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05] ] ]
# [12, 6, 12], [3, 9, 9], [0.25, 0.9, 0.1, 0.05], 'local' ] ]
# [12, 2, 4], [3, 27, 27], [0.5, 0.9, 0.1, 0.05], 'local' ] ]

cf.fields_targets = []

cf.years_train = [2021] # list( range( 1980, 2018))
cf.years_train = list( range( 2010, 2021))
cf.years_val = [2021] #[2018]
cf.month = None
cf.geo_range_sampling = [[ -90., 90.], [ 0., 360.]]
cf.time_sampling = 1 # sampling rate for time steps
# random seeds
cf.torch_seed = torch.initial_seed()
# training params
cf.batch_size_validation = 64
cf.batch_size = 32
cf.batch_size_validation = 1 #64
cf.batch_size = 96
cf.num_epochs = 128
cf.num_samples_per_epoch = 4096*12
cf.num_samples_validate = 128*12
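
For readers decoding the cf.fields entries in the hunk above: only the masking-rate slot is documented by the legend comment, so the remaining labels below are inferred from how the sampler indexes field_info (e.g. field_info[2] is iterated as vertical levels) and should be read as assumptions:

# best-effort annotation of the active field descriptor from this hunk
field = [ 'velocity_u',                # field name
          [ 1, 1024, [ ], 0 ],         # field options, incl. embedding dimension 1024 (inferred)
          [ 96, 105, 114, 123, 137 ],  # vertical (model) levels
          [12, 3, 6],                  # tokens per dimension: time x lat x lon (inferred)
          [3, 18, 18],                 # token size per dimension: time x lat x lon (inferred)
          [0.5, 0.9, 0.2, 0.05] ]      # [ total masking rate, rate masking, rate noising,
                                       #   rate for multi-res distortion ]
# an optional trailing 'local' entry (see the commented variants) appears to select
# per-sample rather than global normalization -- an inference from the
# corr_type != 'global' branch in multifield_data_sampler.py
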
@@ -161,12 +160,12 @@ def train() :
cf.dropout_rate = 0.05
cf.with_qk_lnorm = False
# encoder
cf.encoder_num_layers = 4
cf.encoder_num_layers = 6
cf.encoder_num_heads = 16
cf.encoder_num_mlp_layers = 2
cf.encoder_att_type = 'dense'
# decoder
cf.decoder_num_layers = 4
cf.decoder_num_layers = 6
cf.decoder_num_heads = 16
cf.decoder_num_mlp_layers = 2
cf.decoder_self_att = False
@@ -177,19 +176,19 @@ def train() :
cf.net_tail_num_nets = 16
cf.net_tail_num_layers = 0
# loss
cf.losses = ['mse_ensemble', 'stats'] # mse, mse_ensemble, stats, crps
cf.losses = ['mse_ensemble', 'stats'] # mse, mse_ensemble, stats, crps, weighted_mse
# training
cf.optimizer_zero = False
cf.lr_start = 5. * 10e-7
cf.lr_max = 0.00005
cf.lr_min = 0.00002
cf.weight_decay = 0.1
cf.lr_max = 0.00005*3
cf.lr_min = 0.00004 #0.00002
cf.weight_decay = 0.05 #0.1
cf.lr_decay_rate = 1.025
cf.lr_start_epochs = 3
# BERT
# strategies: 'BERT', 'forecast', 'temporal_interpolation'
cf.BERT_strategy = 'BERT'
cf.forecast_num_tokens = 1 # only needed / used for BERT_strategy 'forecast'
cf.forecast_num_tokens = 2 # only needed / used for BERT_strategy 'forecast'
cf.BERT_fields_synced = False # apply synchronized / identical masking to all fields
# (fields need to have same BERT params for this to have effect)
cf.BERT_mr_max = 2 # maximum reduction rate for resolution
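
The learning-rate fields changed above (lr_max tripled, lr_min raised, weight_decay halved) feed a schedule that lives in the trainer and is not part of this diff; the sketch below is therefore only an assumed reading of the field names: a warm-up over lr_start_epochs followed by a multiplicative decay floored at lr_min.

def lr_at_epoch( epoch, lr_start = 5. * 10e-7, lr_max = 0.00005*3,
                 lr_min = 0.00004, lr_decay_rate = 1.025, lr_start_epochs = 3) :
  # assumed schedule; the actual implementation is in the trainer, not shown here
  if epoch < lr_start_epochs :
    # linear warm-up from lr_start to lr_max
    return lr_start + (lr_max - lr_start) * epoch / lr_start_epochs
  # multiplicative decay after warm-up, floored at lr_min
  return max( lr_max / lr_decay_rate**(epoch - lr_start_epochs), lr_min)
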
@@ -219,12 +218,13 @@ def train() :
# # # cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res025_chunk8.zarr'
# # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk8_lat180_lon180.zarr'
# # # cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res025_chunk16.zarr'
cf.file_path = '/gpfs/scratch/ehpc03/era5_y2010_2021_res025_chunk8.zarr/'
# # # in steps x lat_degrees x lon_degrees
# cf.n_size = [36, 0.25*9*6, 0.25*9*12]
cf.n_size = [36, 0.25*9*6, 0.25*9*12]

# cf.file_path = '/ec/res4/scratch/nacl/atmorep/era5_y2021_res100_chunk16.zarr'
cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr'
cf.n_size = [36, 1*9*6, 1.*9*12]
#cf.file_path = '/p/scratch/atmo-rep/data/era5_1deg/months/era5_y2021_res100_chunk16.zarr'
#cf.n_size = [36, 1*9*6, 1.*9*12]

if cf.with_wandb and 0 == cf.par_rank :
cf.write_json( wandb)
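
The n_size comment gives the units as time steps x lat degrees x lon degrees, so the active setting works out as follows:

# extent of one sample implied by cf.n_size for the 0.25-degree dataset
resolution = 0.25                                 # degrees per grid point
n_size = [36, resolution*9*6, resolution*9*12]    # [36, 13.5, 27.0]
# i.e. 36 time steps, 13.5 degrees of latitude, 27 degrees of longitude per sample
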
47 changes: 24 additions & 23 deletions atmorep/datasets/multifield_data_sampler.py
@@ -32,7 +32,7 @@ class MultifieldDataSampler( torch.utils.data.IterableDataset):

###################################################
def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size,
num_samples, with_shuffle = False, time_sampling = 1, with_source_idxs = False,
num_samples, with_shuffle = False, time_sampling = 1, with_source_idxs = False, compute_weights = False,
fields_targets = None, pre_batch_targets = None ) :
'''
Data set for single dynamic field at an arbitrary number of vertical levels
@@ -46,6 +46,7 @@ def __init__( self, file_path, fields, years, batch_size, pre_batch, n_size,
self.n_size = n_size
self.num_samples = num_samples
self.with_source_idxs = with_source_idxs
self.compute_weights = compute_weights
self.with_shuffle = with_shuffle
self.pre_batch = pre_batch

@@ -185,11 +186,11 @@ def __iter__(self):

source_data, tok_info = [], []
# extract data, normalize and tokenize
cdata = data_t[ : , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]]
cdata = data_t[ ... , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]]

normalizer = self.normalizers[ifield][ilevel]
if corr_type != 'global':
normalizer = normalizer[ : , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]]
if corr_type != 'global':
normalizer = normalizer[ ... , lat_ran[:,np.newaxis], lon_ran[np.newaxis,:]]
cdata = normalize(cdata, normalizer, sources_infos[-1][0], year_base = self.year_base)

source_data = tokenize( torch.from_numpy( cdata), tok_size )
@@ -217,29 +218,29 @@ def __iter__(self):

tmidx_list = sources[-1]
weights_idx_list = []
if self.compute_weights:
for ifield, field_info in enumerate(self.fields):
weights = []
for ilevel, vl in enumerate(field_info[2]):
for ibatch in range(self.batch_size):

lats_idx = source_idxs[ibatch][1]
lons_idx = source_idxs[ibatch][2]

for ifield, field_info in enumerate(self.fields):
weights = []
for ilevel, vl in enumerate(field_info[2]):
for ibatch in range(self.batch_size):

lats_idx = source_idxs[ibatch][1]
lons_idx = source_idxs[ibatch][2]

idx_base = tmidx_list[ifield][ilevel][ibatch]
idx_loc = idx_base - np.prod(num_tokens) * ibatch

grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1
grid = torch.from_numpy( np.array( np.broadcast_to( grid,
shape = [tok_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1))
idx_base = tmidx_list[ifield][ilevel][ibatch]
idx_loc = idx_base - np.prod(num_tokens) * ibatch

grid = np.flip(np.array( np.meshgrid( lons_idx, lats_idx)), axis = 0) #flip to have lat on pos 0 and lon on pos 1
grid = torch.from_numpy( np.array( np.broadcast_to( grid,
shape = [tok_size[0]*num_tokens[0], *grid.shape])).swapaxes(0,1))

grid_lats_toked = tokenize( grid[0], tok_size).flatten( 0, 2)
grid_lats_toked = tokenize( grid[0], tok_size).flatten( 0, 2)

lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()])
lats_mskd_b = np.array([np.unique(t) for t in grid_lats_toked[ idx_loc ].numpy()])

weights.append([get_weights(la) for la in lats_mskd_b])
weights.append([get_weights(la) for la in lats_mskd_b])

weights_idx_list.append(weights)
weights_idx_list.append(weights)
sources = (*sources, weights_idx_list)

# TODO: implement (only required when prediction target comes from different data stream)
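
get_weights itself is defined outside this diff. For a regular lat-lon grid the standard area-correct choice is cosine-of-latitude weighting, so a plausible sketch, explicitly an assumption rather than the function the sampler imports, is:

import numpy as np

def get_weights( lats_deg) :
  # assumed cos-latitude area weights for the masked-token latitudes;
  # the real get_weights used above is not shown in this commit
  w = np.cos( np.deg2rad( np.asarray( lats_deg, dtype=np.float64)))
  return w / w.mean()   # normalize so the mean weight is 1

# usage on the unique latitudes of one masked token
print( get_weights( [0., 30., 60.]))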
