Merge pull request #27 from clessig/develop
merge all new developments to main
iluise authored Aug 12, 2024
2 parents bf8628d + e3412a6 commit ac3c433
Showing 32 changed files with 1,534 additions and 2,320 deletions.
6 changes: 1 addition & 5 deletions atmorep/config/config.py
@@ -3,12 +3,8 @@

fpath = os.path.dirname(os.path.realpath(__file__))

year_base = 1979
year_last = 2022

path_models = Path( fpath, '../../models/')
path_results = Path( fpath, '../../results/')
path_data = Path( fpath, '../../data/')
path_results = Path( fpath, '../../results')
path_plots = Path( fpath, '../results/plots/')

grib_index = { 'vorticity' : 'vo', 'divergence' : 'd', 'geopotential' : 'z',
194 changes: 103 additions & 91 deletions atmorep/core/atmorep_model.py
@@ -35,6 +35,7 @@
from atmorep.transformer.transformer_decoder import TransformerDecoder
from atmorep.transformer.tail_ensemble import TailEnsemble


####################################################################################################
class AtmoRepData( torch.nn.Module) :

@@ -53,37 +54,6 @@ def __init__( self, net) :
self.rng_seed = net.cf.rng_seed
if not self.rng_seed :
self.rng_seed = int(torch.randint( 100000000, (1,)))

###################################################
def load_data( self, mode : NetMode, batch_size = -1, num_loader_workers = -1) :
'''Load data'''

cf = self.net.cf

if batch_size < 0 :
batch_size = cf.batch_size_max
if num_loader_workers < 0 :
num_loader_workers = cf.num_loader_workers

if mode == NetMode.train :
self.data_loader_train = self._load_data( self.dataset_train, batch_size, num_loader_workers)
elif mode == NetMode.test :
batch_size = cf.batch_size_test
self.data_loader_test = self._load_data( self.dataset_test, batch_size, num_loader_workers)
else :
assert False

###################################################
def _load_data( self, dataset, batch_size, num_loader_workers) :
'''Private implementation for load'''

dataset.load_data( batch_size)

loader_params = { 'batch_size': None, 'batch_sampler': None, 'shuffle': False,
'num_workers': num_loader_workers, 'pin_memory': True}
data_loader = torch.utils.data.DataLoader( dataset, **loader_params, sampler = None)

return data_loader

###################################################
def set_data( self, mode : NetMode, times_pos, batch_size = -1, num_loader_workers = -1) :
@@ -94,7 +64,7 @@ def set_data( self, mode : NetMode, times_pos, batch_size = -1, num_loader_worke

dataset = self.dataset_train if mode == NetMode.train else self.dataset_test
dataset.set_data( times_pos, batch_size)

self._set_data( dataset, mode, batch_size, num_loader_workers)

###################################################
@@ -103,7 +73,6 @@ def set_global( self, mode : NetMode, times, batch_size = -1, num_loader_workers
cf = self.net.cf
if batch_size < 0 :
batch_size = cf.batch_size_train if mode == NetMode.train else cf.batch_size_test

dataset = self.dataset_train if mode == NetMode.train else self.dataset_test
dataset.set_global( times, batch_size, cf.token_overlap)

@@ -143,7 +112,7 @@ def _set_data( self, dataset, mode : NetMode, batch_size = -1, loader_workers =
assert False

###################################################
def normalizer( self, field, vl_idx) :
def normalizer( self, field, vl_idx, lats_idx, lons_idx ) :

if isinstance( field, str) :
for fidx, field_info in enumerate(self.cf.fields) :
@@ -153,12 +122,15 @@ def normalizer( self, field, vl_idx) :
normalizer = self.dataset_train.datasets[fidx].normalizer

elif isinstance( field, int) :
normalizer = self.dataset_train.datasets[field][vl_idx].normalizer

normalizer = self.dataset_train.normalizers[field][vl_idx]
if len(normalizer.shape) > 2:
normalizer = np.take( np.take( normalizer, lats_idx, -2), lons_idx, -1)
else :
assert False, 'invalid argument type (has to be index to cf.fields or field name)'

year_base = self.dataset_train.year_base

return normalizer
return normalizer, year_base

###################################################
def mode( self, mode : NetMode) :
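
For context on the normalizer() change above: the method now also takes latitude/longitude index arrays, restricts per-grid-point statistics to the requested patch, and returns the dataset's base year alongside the normalizer. Below is a minimal illustrative sketch of the nested np.take subsetting pattern; it is not part of the commit, and the array shape and index ranges are made up.

import numpy as np

# Hypothetical per-grid-point normalizer: (months, mean/std, nlat, nlon).
normalizer = np.random.rand(12, 2, 721, 1440)

lats_idx = np.arange(100, 136)   # assumed rows of the local patch
lons_idx = np.arange(200, 236)   # assumed columns of the local patch

# np.take along the last two axes keeps only the statistics for the patch.
patch_norm = np.take(np.take(normalizer, lats_idx, -2), lons_idx, -1)
print(patch_norm.shape)   # (12, 2, 36, 36)
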
@@ -193,8 +165,8 @@ def forward( self, xin) :
return pred

###################################################
def get_attention( self, xin): #, field_idx) :
attn = self.net.get_attention( xin) #, field_idx)
def get_attention( self, xin) :
attn = self.net.get_attention( xin)
return attn

###################################################
@@ -208,40 +180,26 @@ def create( self, pre_batch, devices, create_net = True, pre_batch_targets = Non
self.pre_batch_targets = pre_batch_targets

cf = self.net.cf
self.dataset_train = MultifieldDataSampler( cf.data_dir, cf.years_train, cf.fields,
batch_size = cf.batch_size_start,
num_t_samples = cf.num_t_samples,
num_patches_per_t = cf.num_patches_per_t_train,
num_load = cf.num_files_train,
pre_batch = self.pre_batch,
rng_seed = self.rng_seed,
file_shape = cf.file_shape,
smoothing = cf.data_smoothing,
level_type = cf.level_type,
file_format = cf.file_format,
month = cf.month,
time_sampling = cf.time_sampling,
geo_range = cf.geo_range_sampling,
fields_targets = cf.fields_targets,
pre_batch_targets = self.pre_batch_targets )

self.dataset_test = MultifieldDataSampler( cf.data_dir, cf.years_test, cf.fields,
batch_size = cf.batch_size_test,
num_t_samples = cf.num_t_samples,
num_patches_per_t = cf.num_patches_per_t_test,
num_load = cf.num_files_test,
pre_batch = self.pre_batch,
rng_seed = self.rng_seed,
file_shape = cf.file_shape,
smoothing = cf.data_smoothing,
level_type = cf.level_type,
file_format = cf.file_format,
month = cf.month,
time_sampling = cf.time_sampling,
geo_range = cf.geo_range_sampling,
lat_sampling_weighted = cf.lat_sampling_weighted,
fields_targets = cf.fields_targets,
pre_batch_targets = self.pre_batch_targets )
loader_params = { 'batch_size': None, 'batch_sampler': None, 'shuffle': False,
'num_workers': cf.num_loader_workers, 'pin_memory': True}

self.dataset_train = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_train,
cf.batch_size,
pre_batch, cf.n_size, cf.num_samples_per_epoch,
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True,
compute_weights = (cf.losses.count('weighted_mse') > 0) )
self.data_loader_train = torch.utils.data.DataLoader( self.dataset_train, **loader_params,
sampler = None)

self.dataset_test = MultifieldDataSampler( cf.file_path, cf.fields, cf.years_val,
cf.batch_size_validation,
pre_batch, cf.n_size, cf.num_samples_validate,
with_shuffle = (cf.BERT_strategy != 'global_forecast'),
with_source_idxs = True,
compute_weights = (cf.losses.count('weighted_mse') > 0) )
self.data_loader_test = torch.utils.data.DataLoader( self.dataset_test, **loader_params,
sampler = None)

return self
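
For context on the rebuilt create() above: passing batch_size=None (together with batch_sampler=None) disables PyTorch's automatic batching, so the MultifieldDataSampler is expected to yield ready-made batches itself. The following is a small self-contained sketch of that DataLoader pattern with a toy dataset and made-up shapes, not the project's actual sampler.

import torch

class PreBatchedDataset(torch.utils.data.IterableDataset):
    # Toy stand-in for a sampler that already yields complete batches.
    def __iter__(self):
        for _ in range(4):
            yield torch.randn(16, 3, 32, 32)   # one full batch per item

# batch_size=None switches off automatic batching, so each pre-batched
# item passes through the DataLoader unchanged.
loader = torch.utils.data.DataLoader(PreBatchedDataset(), batch_size=None,
                                     num_workers=0, pin_memory=True)
for batch in loader:
    print(batch.shape)   # torch.Size([16, 3, 32, 32])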

@@ -261,7 +219,6 @@ def create( self, devices, load_pretrained=True) :

cf = self.cf
self.devices = devices
size_token_info = 6
self.fields_coupling_idx = []

self.fields_index = {}
@@ -294,17 +251,9 @@ def create( self, devices, load_pretrained=True) :

self.embeds = torch.nn.ModuleList()
self.encoders = torch.nn.ModuleList()
self.masks = torch.nn.ParameterList()

for field_idx, field_info in enumerate(cf.fields) :

# learnabl class token
if cf.learnable_mask :
mask = torch.nn.Parameter( 0.1 * torch.randn( np.prod( field_info[4]), requires_grad=True))
self.masks.append( mask.to(devices[0]))
else :
self.masks.append( None)

# encoder
self.encoders.append( TransformerEncoder( cf, field_idx, True).create())
# load pre-trained model if specified
@@ -356,11 +305,10 @@ def create( self, devices, load_pretrained=True) :
device = self.devices[0]
if len(field_info[1]) > 3 :
assert field_info[1][3] < 4, 'Only single node model parallelism supported'
print(devices, field_info[1][3])
assert field_info[1][3] < len(devices), 'Per field device id larger than max devices'
device = self.devices[ field_info[1][3] ]
# set device
if self.masks[field_idx] != None :
self.masks[field_idx].to(device)
self.embeds[field_idx].to(device)
self.encoders[field_idx].to(device)

@@ -418,6 +366,68 @@ def load_block( self, field_info, block_name, block ) :
print( 'Loaded {} for {} from id = {} (ignoring/missing {} elements).'.format( block_name,
field_info[0], field_info[1][4][0], len(mkeys) ) )

###################################################
def translate_weights(self, mloaded, mkeys, ukeys):
'''
Function used for backward compatibility
'''
cf = self.cf

#encoder:
for layer in range(cf.encoder_num_layers) :

#shape([16, 3, 128, 2048])
mw = torch.cat([mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_{k}.weight'] for head in range(cf.encoder_num_heads) for k in ["qs", "ks", "vs"]])
mloaded[f'encoders.0.heads.{layer}.proj_heads.weight'] = mw

for head in range(cf.encoder_num_heads):
del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_qs.weight']
del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_ks.weight']
del mloaded[f'encoders.0.heads.{layer}.heads_self.{head}.proj_vs.weight']

#cross attention
if f'encoders.0.heads.{layer}.heads_other.0.proj_qs.weight' in ukeys:
mw = torch.cat([mloaded[f'encoders.0.heads.{layer}.heads_other.{head}.proj_{k}.weight'] for head in range(cf.encoder_num_heads) for k in ["qs", "ks", "vs"]])

for i in range(cf.encoder_num_heads):
          del mloaded[f'encoders.0.heads.{layer}.heads_other.{i}.proj_qs.weight']
          del mloaded[f'encoders.0.heads.{layer}.heads_other.{i}.proj_ks.weight']
          del mloaded[f'encoders.0.heads.{layer}.heads_other.{i}.proj_vs.weight']

else:
dim_mw = self.encoders[0].heads[0].proj_heads_other[0].weight.shape
mw = torch.tensor(np.zeros(dim_mw))

mloaded[f'encoders.0.heads.{layer}.proj_heads_other.0.weight'] = mw

#decoder
for iblock in range(0, 19, 2) :
mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads.{head}.proj_{k}.weight'] for head in range(8) for k in ["qs", "ks", "vs"]])
mloaded[f'decoders.0.blocks.{iblock}.proj_heads.weight'] = mw

qs = [mloaded[f'decoders.0.blocks.{iblock}.heads_other.{head}.proj_qs.weight'] for head in range(8)]
mw = torch.cat([mloaded[f'decoders.0.blocks.{iblock}.heads_other.{head}.proj_{k}.weight'] for head in range(8) for k in ["ks", "vs"]])

mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_q.weight'] = torch.cat([*qs])
mloaded[f'decoders.0.blocks.{iblock}.proj_heads_o_kv.weight'] = mw

#self.num_samples_validate
decoder_dim = self.decoders[0].blocks[iblock].ln_q.weight.shape #128
mloaded[f'decoders.0.blocks.{iblock}.ln_q.weight'] = torch.tensor(np.ones(decoder_dim))
mloaded[f'decoders.0.blocks.{iblock}.ln_k.weight'] = torch.tensor(np.ones(decoder_dim))
mloaded[f'decoders.0.blocks.{iblock}.ln_q.bias'] = torch.tensor(np.ones(decoder_dim))
mloaded[f'decoders.0.blocks.{iblock}.ln_k.bias'] = torch.tensor(np.ones(decoder_dim))

for i in range(8):
del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_qs.weight']
del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_ks.weight']
del mloaded[f'decoders.0.blocks.{iblock}.heads.{i}.proj_vs.weight']
del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_qs.weight']
del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_ks.weight']
del mloaded[f'decoders.0.blocks.{iblock}.heads_other.{i}.proj_vs.weight']

return mloaded
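
The weight translation above relies on a simple fact: stacking the per-head q/k/v projection weight matrices row-wise gives one fused projection whose output is the concatenation of the individual per-head outputs. A short sketch of that equivalence with made-up sizes (not part of the commit):

import torch

torch.manual_seed(0)
dim, head_dim, num_heads = 64, 16, 4          # made-up sizes

# Per-head q/k/v projection weights as stored in older checkpoints.
per_head = [torch.randn(3 * head_dim, dim) for _ in range(num_heads)]

# Row-wise concatenation, analogous to the torch.cat calls in translate_weights().
fused = torch.cat(per_head)

x = torch.randn(dim)
# The fused projection reproduces the stacked per-head projections exactly.
assert torch.allclose(fused @ x, torch.cat([w @ x for w in per_head]))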

###################################################
@staticmethod
def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) :
Expand All @@ -429,15 +439,18 @@ def load( model_id, devices, cf = None, epoch = -2, load_pretrained=False) :

model = AtmoRep( cf).create( devices, load_pretrained=False)
mloaded = torch.load( utils.get_model_filename( model, model_id, epoch) )
mkeys, _ = model.load_state_dict( mloaded, False )
mkeys, ukeys = model.load_state_dict( mloaded, False )
if (f'encoders.0.heads.0.proj_heads.weight') in mkeys:
mloaded = model.translate_weights(mloaded, mkeys, ukeys)
mkeys, ukeys = model.load_state_dict( mloaded, False )

if len(mkeys) > 0 :
print( f'Loaded AtmoRep: ignoring {len(mkeys)} elements: {mkeys}')

# TODO: remove, only for backward
if model.embeds_token_info[0].weight.abs().max() == 0. :
model.embeds_token_info = torch.nn.ModuleList()

return model

###################################################
@@ -474,8 +487,9 @@ def forward( self, xin) :

# embedding
cf = self.cf

fields_embed = self.get_fields_embed(xin)

# attention maps (if requested)
atts = [ [] for _ in cf.fields ]

@@ -528,16 +542,14 @@ def forward_encoder_block( self, iblock, fields_embed) :
return fields_embed_cur, atts

###################################################

def get_fields_embed( self, xin ) :
cf = self.cf
if 0 == len(self.embeds_token_info) : # TODO: only for backward compatibility, remove
emb_net_ti = self.embed_token_info
return [prepare_token( field_data, emb_net, emb_net_ti, cf.with_cls )
return [prepare_token( field_data, emb_net, emb_net_ti )
for fidx,(field_data,emb_net) in enumerate(zip( xin, self.embeds))]
else :
embs_net_ti = self.embeds_token_info
return [prepare_token( field_data, emb_net, embs_net_ti[fidx], cf.with_cls )
return [prepare_token( field_data, emb_net, embs_net_ti[fidx] )
for fidx,(field_data,emb_net) in enumerate(zip( xin, self.embeds))]

###################################################
(Diffs for the remaining changed files are not shown.)
