swap lon lat in all dataset / datamodules related code
liellnima committed Sep 3, 2024
1 parent f07a076 commit c723ee1
Showing 5 changed files with 27 additions and 27 deletions.
19 changes: 10 additions & 9 deletions emulator/src/data/climate_dataset.py
@@ -14,8 +14,8 @@
 
 from emulator.src.utils.utils import get_logger, map_variables_targetmip
 from emulator.src.data.constants import (
-    LON,
     LAT,
+    LON,
     SEQ_LEN,
     INPUT4MIPS_TEMP_RES,
     CMIP6_TEMP_RES,
@@ -140,7 +140,7 @@ def __init__(
             **ds_kwargs,
         )
 
-    # this operates variable vise now....
+    # this operates on each variable (i.e. climateset and input4mips) now
     def load_into_mem(
         self,
         paths: List[List[str]],
@@ -157,7 +157,7 @@ def load_into_mem(
             ).compute()  # .compute is not necessary but eh, doesn't hurt
             temp_data = (
                 temp_data.to_array().to_numpy()
-            )  # Should be of shape (vars, years*ensemble_members*num_scenarios, lon, lat)
+            )  # Should be of shape (vars, years*ensemble_members*num_scenarios, lat, lon)
             array_list.append(temp_data)
         temp_data = np.concatenate(array_list, axis=0)
 
@@ -184,16 +184,17 @@ def load_into_mem(
         new_shape_one = int(temp_data.shape[1] / seq_len)
 
         temp_data = temp_data.reshape(
-            num_vars, new_shape_one, seq_len, LON, LAT
-        )  # num_vars, num_scenarios*num_remainding_years, seq_len, lon, lat)
+            num_vars, new_shape_one, seq_len, LAT, LON
+        )  # num_vars, num_scenarios*num_remaining_years, seq_len, lat, lon)
         if seq_to_seq == False:
             temp_data = temp_data[:, :, -1, :, :]  # only take last time step
             temp_data = np.expand_dims(temp_data, axis=2)
         if channels_last:
             temp_data = temp_data.transpose((1, 2, 3, 4, 0))
         else:
             temp_data = temp_data.transpose((1, 2, 0, 3, 4))
-        return temp_data  # (years*num_scenarios, seq_len, vars, lon, lat)
+
+        return temp_data  # (years*num_scenarios, seq_len, vars, lat, lon)
 
     def save_data_into_disk(
         self, data: np.ndarray, fname: str, output_save_dir: str
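For readers tracing the new ordering, here is a minimal, self-contained sketch of the reshape/transpose flow that load_into_mem performs after this change. It is not the repository's code; all sizes (including LAT and LON) are made-up placeholders, not the emulator's real constants.

# Standalone sketch of load_into_mem's shape flow; sizes are placeholders.
import numpy as np

NUM_VARS, NUM_CHUNKS, SEQ_LEN, LAT, LON = 2, 3, 10, 96, 144

# concatenated input: (vars, years*ensemble_members*num_scenarios, lat, lon)
temp_data = np.zeros((NUM_VARS, NUM_CHUNKS * SEQ_LEN, LAT, LON))

# -> (num_vars, num_scenarios*num_remaining_years, seq_len, lat, lon)
temp_data = temp_data.reshape(NUM_VARS, NUM_CHUNKS, SEQ_LEN, LAT, LON)

channels_last = True
if channels_last:
    temp_data = temp_data.transpose((1, 2, 3, 4, 0))  # vars move to the last axis
else:
    temp_data = temp_data.transpose((1, 2, 0, 3, 4))  # vars move after seq_len

assert temp_data.shape == (NUM_CHUNKS, SEQ_LEN, LAT, LON, NUM_VARS)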
@@ -296,7 +296,7 @@ def get_dataset_statistics(self, data, mode, type="z-norm", mips="cmip6"):
             print("In testing mode, skipping statistics calculations.")
 
     def get_mean_std(self, data):
-        # data shape (years*scenarios, seq, vars, lon, lat)
+        # data shape (years*scenarios, seq, vars, lat, lon)
         if self.channels_last:
             data = np.moveaxis(data, -1, 0)
         else:
@@ -333,11 +333,11 @@ def normalize_data(self, data, stats, type="z-norm"):
         if self.channels_last:
             data = np.moveaxis(
                 data, -1, 0
-            )  # vars from last to 0 (num_vars, years, seq_len, lon, lat)
+            )  # vars from last to 0 (num_vars, years, seq_len, lat, lon)
         else:
             data = np.moveaxis(
                 data, 2, 0
-            )  # shape (years, seq_len, num_vars, lon, lat) -> (num_vars, years, seq_len, lon, lat)
+            )  # shape (years, seq_len, num_vars, lat, lon) -> (num_vars, years, seq_len, lat, lon)
 
         print("mean", stats["mean"].shape, "std", stats["std"].shape)
         norm_data = (data - stats["mean"]) / (stats["std"])
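As a sanity check on the statistics path, here is a hedged sketch of what get_mean_std / normalize_data do with the new (lat, lon) ordering. The per-variable reduction axes are an assumption inferred from the moveaxis calls above, and all sizes are placeholders.

import numpy as np

# channels_first data: (years*scenarios, seq, vars, lat, lon)
years, seq_len, num_vars, lat, lon = 6, 10, 3, 32, 64
data = np.random.rand(years, seq_len, num_vars, lat, lon)

# vars to axis 0, as in normalize_data: (num_vars, years, seq, lat, lon)
data = np.moveaxis(data, 2, 0)

# assumed: one mean/std per variable, reduced over all remaining axes
mean = data.mean(axis=(1, 2, 3, 4), keepdims=True)
std = data.std(axis=(1, 2, 3, 4), keepdims=True)

norm_data = (data - mean) / std
assert norm_data.shape == (num_vars, years, seq_len, lat, lon)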
8 changes: 4 additions & 4 deletions emulator/src/datamodules/climate_datamodule.py
@@ -67,8 +67,8 @@ def __init__(
         seq_len: int = SEQ_LEN_MAPPING[TEMP_RES],
         output_save_dir: Optional[str] = DATA_DIR,
         num_ensembles: int = 1,  # 1 for first ensemble, -1 for all
-        lon: int = LON,
         lat: int = LAT,
+        lon: int = LON,
         num_levels: int = NUM_LEVELS,
         name: str = "climate",
         # input_transform: Optional[AbstractTransform] = None,
@@ -202,9 +202,9 @@ def _shared_eval_dataloader_kwargs(self) -> dict:
         )
 
     # resulting tensor sizes:
-    # x: (batch_size, sequence_length, lon, lat, in_vars) if channels_last else (batch_size, sequence_length, in_vars, lon, lat)
-    # y: (batch_size, sequence_length, lon, lat, out_vars) if channels_last else (batch_size, sequence_length, out_vars, lon, lat)
-    def train_dataloader(self):
+    # x: (batch_size, sequence_length, lat, lon, in_vars) if channels_last else (batch_size, sequence_length, in_vars, lat, lon)
+    # y: (batch_size, sequence_length, lat, lon, out_vars) if channels_last else (batch_size, sequence_length, out_vars, lat, lon)
+    def train_dataloader(self):  # TODO: does this really give us the right shape?
         return DataLoader(
             dataset=self._data_train,
             batch_size=self.hparams.batch_size,
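The corrected comment above can be verified with a self-contained sketch for channels_last=True. This uses a plain TensorDataset stand-in with assumed sizes, not the project's datasets.

import torch
from torch.utils.data import DataLoader, TensorDataset

size, batch_size, seq_len, lat, lon = 100, 16, 10, 32, 64
in_vars, out_vars = 4, 2

# channels_last layout: variables occupy the final axis
x = torch.randn(size, seq_len, lat, lon, in_vars)
y = torch.randn(size, seq_len, lat, lon, out_vars)

loader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)
xb, yb = next(iter(loader))
assert xb.shape == (batch_size, seq_len, lat, lon, in_vars)
assert yb.shape == (batch_size, seq_len, lat, lon, out_vars)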
14 changes: 7 additions & 7 deletions emulator/src/datamodules/dummy_datamodule.py
@@ -38,8 +38,8 @@ def __init__(
         out_var_ids: List[str] = ["pr", "tas"],
         seq_len: int = 10,
         seq_to_seq: bool = True,  # if true maps from T->T else from T->1
-        lon: int = 32,
         lat: int = 32,
+        lon: int = 32,
         num_levels: int = 1,
         channels_last: bool = True,  # whether variables come last or after the sequence length
         batch_size: int = 16,
@@ -66,9 +66,9 @@ def __init__(
             out_var_ids: List(str): Ids of output variables.
             seq_len (int): Length of the input sequence (in time).
             seq_to_seq (bool): If true maps from seq_len to seq_len else from seq_len to 1.
-            lon (int): Longitude of grid.
             lat (int): Latitude of grid.
-            channels_last (int): If true, shape of tensors (batch_size, time, lon, lat, channels) else (batch_size, time, channels, lon, lat). Important for some torch layers.
+            lon (int): Longitude of grid.
+            channels_last (bool): If true, shape of tensors (batch_size, time, lat, lon, channels) else (batch_size, time, channels, lat, lon). Important for some torch layers.
             size (int): Size (num examples) of the dummy dataset.
             test_split (float): Fraction of data to use for testing.
             val_split (float): Fraction of data to use for evaluation.
@@ -115,17 +115,17 @@ def setup(self, stage: Optional[str] = None):
                 size=(
                     self.hparams.size,
                     self.hparams.seq_len,
-                    self.hparams.lon,
                     self.hparams.lat,
+                    self.hparams.lon,
                     len(self.hparams.in_var_ids),
                 )
             )
             targets = torch.ones(
                 size=(
                     self.hparams.size,
                     self.out_seq_len,
-                    self.hparams.lon,
                     self.hparams.lat,
+                    self.hparams.lon,
                     len(self.hparams.out_var_ids),
                 )
             )
@@ -136,17 +136,17 @@ def setup(self, stage: Optional[str] = None):
                     self.hparams.size,
                     self.hparams.seq_len,
                     len(self.hparams.in_var_ids),
-                    self.hparams.lon,
                     self.hparams.lat,
+                    self.hparams.lon,
                 )
             )
             targets = torch.ones(
                 size=(
                     self.hparams.size,
                     self.out_seq_len,
                     len(self.hparams.out_var_ids),
-                    self.hparams.lon,
                     self.hparams.lat,
+                    self.hparams.lon,
                 )
             )
 
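To see how the two hunks above interact with seq_to_seq and channels_last, here is a compact sketch of the dummy setup() logic with assumed sizes; it mirrors the diff but is not the module's code.

import torch

size, seq_len, lat, lon = 8, 10, 32, 32
in_vars, out_vars = 4, 2
seq_to_seq, channels_last = False, False

# T->T keeps the full sequence; T->1 collapses targets to one time step
out_seq_len = seq_len if seq_to_seq else 1

if channels_last:
    inputs = torch.ones(size, seq_len, lat, lon, in_vars)
    targets = torch.ones(size, out_seq_len, lat, lon, out_vars)
else:
    inputs = torch.ones(size, seq_len, in_vars, lat, lon)
    targets = torch.ones(size, out_seq_len, out_vars, lat, lon)

assert targets.shape == (size, out_seq_len, out_vars, lat, lon)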
2 changes: 1 addition & 1 deletion emulator/src/datamodules/super_climate_datamodule.py
@@ -261,8 +261,8 @@ def __init__(
         data_dir: Optional[str] = DATA_DIR,
         output_save_dir: Optional[str] = DATA_DIR,
         num_ensembles: int = 1,
-        lon: int = LON,
         lat: int = LAT,
+        lon: int = LON,
         num_levels: int = NUM_LEVELS,
         name: str = "super_climate"
     ):
11 changes: 5 additions & 6 deletions emulator/src/datamodules/template_datamodule.py
@@ -40,10 +40,10 @@ def __init__(
         ],  # do we want to implement keeping only certain years for testing?
         seq_len: int = 10,
         seq_to_seq: bool = True,  # if true maps from T->T else from T->1
-        lon: int = 32,
         lat: int = 32,
+        lon: int = 32,
         num_levels: int = 1,
-        channels_last: bool = True,  # wheather variables come last our after sequence lenght
+        channels_last: bool = True,  # whether variables come last or after the sequence length
         train_scenarios: List[str] = ["historical", "ssp126"],
         test_scenarios: List[str] = ["ssp345"],
         val_scenarios: List[str] = ["ssp119"],
@@ -70,9 +70,8 @@ def __init__(
             out_var_ids: List(str): Ids of output variables.
             seq_len (int): Length of the input sequence (in time).
             seq_to_seq (bool): If true maps from seq_len to seq_len else from seq_len to 1.
-            lon (int): Longitude of grid.
             lat (int): Latitude of grid.
+            lon (int): Longitude of grid.
             batch_size (int): Batch size for the training dataloader
             eval_batch_size (int): Batch size for the test and validation dataloaders
             num_workers (int): Dataloader arg for higher efficiency
@@ -143,8 +143,8 @@ def _shared_eval_dataloader_kwargs(self) -> dict:
 
     # Probably we also just want a list of train dataloaders, not a single one, so we can switch sets in our memory
     # resulting tensor sizes:
-    # x: (batch_size, sequence_length, lon, lat, in_vars) if channels_last else (batch_size, sequence_length, in_vars, lon, lat)
-    # y: (batch_size, sequence_length, lon, lat, out_vars) if channels_last else (batch_size, sequence_length, out_vars, lon, lat)
+    # x: (batch_size, sequence_length, lat, lon, in_vars) if channels_last else (batch_size, sequence_length, in_vars, lat, lon)
+    # y: (batch_size, sequence_length, lat, lon, out_vars) if channels_last else (batch_size, sequence_length, out_vars, lat, lon)
     def train_dataloader(self):
         return DataLoader(
             dataset=self._data_train,
