From c9267eba536b4daa76f769a0091eadec040e57c6 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 30 Aug 2023 08:59:28 -0600 Subject: [PATCH 1/9] patch for using dual datahandler with hr_spatial_coarsen > 1. Before this broke the regridding bc lat_lon wasn't being updated after hr_spatial_coarsening. --- sup3r/pipeline/forward_pass.py | 26 +- sup3r/preprocessing/data_handling/base.py | 557 ++++++++---------- .../data_handling/dual_data_handling.py | 23 +- sup3r/preprocessing/dual_batch_handling.py | 101 +--- sup3r/utilities/regridder.py | 347 +++++------ 5 files changed, 435 insertions(+), 619 deletions(-) diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 863e1c920..bcf86fc7e 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -19,21 +19,16 @@ import sup3r.bias.bias_transforms import sup3r.models -from sup3r.postprocessing.file_handling import ( - OutputHandler, - OutputHandlerH5, - OutputHandlerNC, -) +from sup3r.postprocessing.file_handling import (OutputHandler, OutputHandlerH5, + OutputHandlerNC) from sup3r.preprocessing.data_handling import ExogenousDataHandler from sup3r.preprocessing.data_handling.base import InputMixIn from sup3r.utilities import ModuleName from sup3r.utilities.cli import BaseCLI from sup3r.utilities.execution import DistributedProcess -from sup3r.utilities.utilities import ( - get_chunk_slices, - get_input_handler_class, - get_source_type, -) +from sup3r.utilities.utilities import (get_chunk_slices, + get_input_handler_class, + get_source_type) np.random.seed(42) @@ -1167,8 +1162,7 @@ def update_input_handler_kwargs(self, strategy): cache_pattern=self.cache_pattern, single_ts_files=self.single_ts_files, handle_features=strategy.handle_features, - val_split=0.0, - ) + val_split=0.0) input_handler_kwargs.update(fwp_input_handler_kwargs) return input_handler_kwargs @@ -1356,11 +1350,9 @@ def lr_crop_slice(self): @property def chunk_shape(self): """Get shape for the current padded spatiotemporal chunk""" - return ( - self.lr_pad_slice[0].stop - self.lr_pad_slice[0].start, - self.lr_pad_slice[1].stop - self.lr_pad_slice[1].start, - self.ti_pad_slice.stop - self.ti_pad_slice.start, - ) + return (self.lr_pad_slice[0].stop - self.lr_pad_slice[0].start, + self.lr_pad_slice[1].stop - self.lr_pad_slice[1].start, + self.ti_pad_slice.stop - self.ti_pad_slice.start) @property def cache_pattern(self): diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index bb5c15041..d2251f3c8 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -17,41 +17,24 @@ from scipy.spatial import KDTree from sup3r.bias.bias_transforms import get_spatial_bc_factors -from sup3r.preprocessing.data_handling.mixin import ( - InputMixIn, - TrainingPrepMixIn, -) +from sup3r.preprocessing.data_handling.mixin import (InputMixIn, + TrainingPrepMixIn, + ) from sup3r.preprocessing.feature_handling import ( - BVFreqMon, - BVFreqSquaredNC, - Feature, - FeatureHandler, - InverseMonNC, - LatLonNC, - PotentialTempNC, - PressureNC, - Rews, - Shear, - TempNC, - UWind, - VWind, - WinddirectionNC, - WindspeedNC, + BVFreqMon, BVFreqSquaredNC, Feature, FeatureHandler, InverseMonNC, + LatLonNC, PotentialTempNC, PressureNC, Rews, Shear, TempNC, UWind, VWind, + WinddirectionNC, WindspeedNC, ) from sup3r.utilities import ModuleName from sup3r.utilities.cli import BaseCLI from sup3r.utilities.interpolation import Interpolator -from sup3r.utilities.utilities import ( - 
estimate_max_workers, - get_chunk_slices, - get_raster_shape, - np_to_pd_times, - spatial_coarsening, - uniform_box_sampler, - uniform_time_sampler, - weighted_box_sampler, - weighted_time_sampler, -) +from sup3r.utilities.utilities import (estimate_max_workers, get_chunk_slices, + get_raster_shape, np_to_pd_times, + spatial_coarsening, uniform_box_sampler, + uniform_time_sampler, + weighted_box_sampler, + weighted_time_sampler, + ) np.random.seed(42) @@ -70,12 +53,9 @@ class DataHandler(FeatureHandler, InputMixIn, TrainingPrepMixIn): # model but are not part of the synthetic output and are not sent to the # discriminator. These are case-insensitive and follow the Unix shell-style # wildcard format. - TRAIN_ONLY_FEATURES = ( - 'BVF*', - 'inversemoninobukhovlength_*', - 'RMOL', - 'topography', - ) + TRAIN_ONLY_FEATURES = ('BVF*', 'inversemoninobukhovlength_*', 'RMOL', + 'topography', + ) def __init__(self, file_paths, @@ -230,9 +210,8 @@ def __init__(self, temporal_slice=temporal_slice) self.file_paths = file_paths - self.features = ( - features if isinstance(features, (list, tuple)) else [features] - ) + self.features = (features if isinstance(features, + (list, tuple)) else [features]) self.val_time_index = None self.max_delta = max_delta self.val_split = val_split @@ -265,35 +244,30 @@ def __init__(self, self._norm_workers = self.worker_kwargs.get('norm_workers', None) self._load_workers = self.worker_kwargs.get('load_workers', None) self._compute_workers = self.worker_kwargs.get('compute_workers', None) - self._worker_attrs = ['_ti_workers', - '_norm_workers', - '_compute_workers', - '_extract_workers', - '_load_workers'] + self._worker_attrs = [ + '_ti_workers', '_norm_workers', '_compute_workers', + '_extract_workers', '_load_workers' + ] self.preflight() - overwrite = (self.overwrite_cache - and self.cache_files is not None + overwrite = (self.overwrite_cache and self.cache_files is not None and all(os.path.exists(fp) for fp in self.cache_files)) if self.try_load and self.load_cached: - logger.info( - f'All {self.cache_files} exist. Loading from cache ' - f'instead of extracting from source files.') + logger.info(f'All {self.cache_files} exist. Loading from cache ' + f'instead of extracting from source files.') self.load_cached_data() elif self.try_load and not self.load_cached: self.clear_data() - logger.info( - f'All {self.cache_files} exist. Call ' - 'load_cached_data() or use load_cache=True to load ' - 'this data from cache files.') + logger.info(f'All {self.cache_files} exist. Call ' + 'load_cached_data() or use load_cache=True to load ' + 'this data from cache files.') else: if overwrite: - logger.info( - f'{self.cache_files} exists but overwrite_cache ' - 'is set to True. Proceeding with extraction.') + logger.info(f'{self.cache_files} exists but overwrite_cache ' + 'is set to True. 
Proceeding with extraction.') self._raster_size_check() self._run_data_init_if_needed() @@ -306,21 +280,25 @@ def __init__(self, if mask_nan and self.data is not None: nan_mask = np.isnan(self.data).any(axis=(0, 1, 3)) - logger.info( - 'Removing {} out of {} timesteps due to NaNs'.format( - nan_mask.sum(), self.data.shape[2] - ) - ) + logger.info('Removing {} out of {} timesteps due to NaNs'.format( + nan_mask.sum(), self.data.shape[2])) self.data = self.data[:, :, ~nan_mask, :] + if (self.hr_spatial_coarsen > 1 + and self.lat_lon.shape == self.raw_lat_lon.shape): + self.lat_lon = spatial_coarsening( + self.lat_lon, + s_enhance=self.hr_spatial_coarsen, + obs_axis=False) + logger.info('Finished intializing DataHandler.') log_mem(logger, log_level='INFO') @property def try_load(self): """Check if we should try to load cache""" - return self._should_load_cache( - self._cache_pattern, self.cache_files, self.overwrite_cache) + return self._should_load_cache(self._cache_pattern, self.cache_files, + self.overwrite_cache) def check_clear_data(self): """Check if data is cached and clear data if not load_cached""" @@ -335,20 +313,18 @@ def _run_data_init_if_needed(self): self.data = self.run_all_data_init() nan_perc = 100 * np.isnan(self.data).sum() / self.data.size if nan_perc > 0: - msg = 'Data has {:.2f}% NaN values!'.format(nan_perc) + msg = 'Data has {:.3f}% NaN values!'.format(nan_perc) logger.warning(msg) warnings.warn(msg) def _raster_size_check(self): """Check if the sample_shape is larger than the requested raster size""" - bad_shape = ( - self.sample_shape[0] > self.grid_shape[0] - and self.sample_shape[1] > self.grid_shape[1]) + bad_shape = (self.sample_shape[0] > self.grid_shape[0] + and self.sample_shape[1] > self.grid_shape[1]) if bad_shape: - msg = ( - f'spatial_sample_shape {self.sample_shape[:2]} is ' - f'larger than the raster size {self.grid_shape}') + msg = (f'spatial_sample_shape {self.sample_shape[:2]} is ' + f'larger than the raster size {self.grid_shape}') logger.warning(msg) warnings.warn(msg) @@ -358,16 +334,12 @@ def _val_split_check(self): if self.data is not None and self.val_split > 0.0: self.data, self.val_data = self.split_data( - val_split=self.val_split, shuffle_time=self.shuffle_time - ) - msg = ( - f'Validation data has shape={self.val_data.shape} ' - f'and sample_shape={self.sample_shape}. Use a smaller ' - 'sample_shape and/or larger val_split.') - check = any( - val_size < samp_size - for val_size, samp_size in zip( - self.val_data.shape, self.sample_shape)) + val_split=self.val_split, shuffle_time=self.shuffle_time) + msg = (f'Validation data has shape={self.val_data.shape} ' + f'and sample_shape={self.sample_shape}. 
Use a smaller ' + 'sample_shape and/or larger val_split.') + check = any(val_size < samp_size for val_size, samp_size in zip( + self.val_data.shape, self.sample_shape)) if check: logger.warning(msg) warnings.warn(msg) @@ -421,8 +393,8 @@ def extract_workers(self): proc_mem /= len(self.time_chunks) n_procs = len(self.time_chunks) * len(self.extract_features) n_procs = int(np.ceil(n_procs)) - extract_workers = estimate_max_workers( - self._extract_workers, proc_mem, n_procs) + extract_workers = estimate_max_workers(self._extract_workers, proc_mem, + n_procs) return extract_workers @property @@ -437,8 +409,8 @@ def compute_workers(self): proc_mem /= len(self.time_chunks) n_procs = len(self.time_chunks) * len(self.derive_features) n_procs = int(np.ceil(n_procs)) - compute_workers = estimate_max_workers( - self._compute_workers, proc_mem, n_procs) + compute_workers = estimate_max_workers(self._compute_workers, proc_mem, + n_procs) return compute_workers @property @@ -457,8 +429,9 @@ def load_workers(self): def norm_workers(self): """Get upper bound on workers used for normalization.""" if self.data is not None: - norm_workers = estimate_max_workers( - self._norm_workers, 2 * self.feature_mem, self.shape[-1]) + norm_workers = estimate_max_workers(self._norm_workers, + 2 * self.feature_mem, + self.shape[-1]) else: norm_workers = self._norm_workers return norm_workers @@ -477,11 +450,9 @@ def time_chunks(self): if self.is_time_independent: self._time_chunks = [slice(None)] else: - self._time_chunks = get_chunk_slices( - len(self.raw_time_index), - self.time_chunk_size, - self.temporal_slice, - ) + self._time_chunks = get_chunk_slices(len(self.raw_time_index), + self.time_chunk_size, + self.temporal_slice) return self._time_chunks @property @@ -508,9 +479,8 @@ def time_chunk_size(self): else: self._time_chunk_size = np.min( [int(1e9 / step_mem), self.n_tsteps]) - logger.info( - 'time_chunk_size arg not specified. Using ' - f'{self._time_chunk_size}.') + logger.info('time_chunk_size arg not specified. 
Using ' + f'{self._time_chunk_size}.') return self._time_chunk_size @property @@ -576,9 +546,10 @@ def noncached_features(self): def extract_features(self): """Features to extract directly from the source handler""" lower_features = [f.lower() for f in self.handle_features] - return [f for f in self.raw_features - if self.lookup(f, 'compute') is None - or Feature.get_basename(f.lower()) in lower_features] + return [ + f for f in self.raw_features if self.lookup(f, 'compute') is None + or Feature.get_basename(f.lower()) in lower_features + ] @property def derive_features(self): @@ -586,7 +557,8 @@ def derive_features(self): derive_features = [ f for f in set( list(self.noncached_features) + list(self.extract_features)) - if f not in self.extract_features] + if f not in self.extract_features + ] return derive_features @property @@ -656,51 +628,45 @@ def preflight(self): start = self.temporal_slice.start stop = self.temporal_slice.stop n_steps = self.n_tsteps - msg = ( - f'Temporal slice step ({self.temporal_slice.step}) does not ' - f'evenly divide the number of time steps ({n_steps})') + msg = (f'Temporal slice step ({self.temporal_slice.step}) does not ' + f'evenly divide the number of time steps ({n_steps})') check = self.temporal_slice.step is None check = check or n_steps % self.temporal_slice.step == 0 if not check: logger.warning(msg) warnings.warn(msg) - msg = ( - f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger ' - 'than the number of time steps in the raw data ' - f'({len(self.raw_time_index)}).') + msg = (f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger ' + 'than the number of time steps in the raw data ' + f'({len(self.raw_time_index)}).') if len(self.raw_time_index) < self.sample_shape[2]: logger.warning(msg) warnings.warn(msg) - msg = ( - f'The requested time slice {self.temporal_slice} conflicts ' - f'with the number of time steps ({len(self.raw_time_index)}) ' - 'in the raw data') + msg = (f'The requested time slice {self.temporal_slice} conflicts ' + f'with the number of time steps ({len(self.raw_time_index)}) ' + 'in the raw data') t_slice_is_subset = start is not None and stop is not None - good_subset = ( - t_slice_is_subset - and (stop - start <= len(self.raw_time_index)) - and stop <= len(self.raw_time_index) - and start <= len(self.raw_time_index)) + good_subset = (t_slice_is_subset + and (stop - start <= len(self.raw_time_index)) + and stop <= len(self.raw_time_index) + and start <= len(self.raw_time_index)) if t_slice_is_subset and not good_subset: logger.error(msg) raise RuntimeError(msg) - msg = ( - f'Initializing DataHandler {self.input_file_info}. ' - f'Getting temporal range {self.time_index[0]!s} to ' - f'{self.time_index[-1]!s} (inclusive) ' - f'based on temporal_slice {self.temporal_slice}') + msg = (f'Initializing DataHandler {self.input_file_info}. 
' + f'Getting temporal range {self.time_index[0]!s} to ' + f'{self.time_index[-1]!s} (inclusive) ' + f'based on temporal_slice {self.temporal_slice}') logger.info(msg) - logger.info( - f'Using max_workers={self.max_workers}, ' - f'norm_workers={self.norm_workers}, ' - f'extract_workers={self.extract_workers}, ' - f'compute_workers={self.compute_workers}, ' - f'load_workers={self.load_workers}, ' - f'ti_workers={self.ti_workers}') + logger.info(f'Using max_workers={self.max_workers}, ' + f'norm_workers={self.norm_workers}, ' + f'extract_workers={self.extract_workers}, ' + f'compute_workers={self.compute_workers}, ' + f'load_workers={self.load_workers}, ' + f'ti_workers={self.ti_workers}') @classmethod def get_lat_lon(cls, file_paths, raster_index, invert_lat=False): @@ -741,12 +707,11 @@ def get_node_cmd(cls, config): initialize DataHandler and run data extraction. """ - import_str = ( - 'from sup3r.preprocessing.data_handling ' - f'import {cls.__name__};\n' - 'import time;\n' - 'from sup3r.pipeline import Status;\n' - 'from rex import init_logger;\n') + import_str = ('from sup3r.preprocessing.data_handling ' + f'import {cls.__name__};\n' + 'import time;\n' + 'from sup3r.pipeline import Status;\n' + 'from rex import init_logger;\n') dh_init_str = get_fun_call_str(cls, config) @@ -763,12 +728,11 @@ def get_node_cmd(cls, config): logger.warning(msg) warnings.warn(msg) - cmd = ( - f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"data_handler = {dh_init_str};\n" - "t_elap = time.time() - t0;\n") + cmd = (f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"data_handler = {dh_init_str};\n" + "t_elap = time.time() - t0;\n") cmd = BaseCLI.add_status_cmd(config, ModuleName.DATA_EXTRACT, cmd) @@ -806,8 +770,8 @@ def get_cache_file_names(self, target = target if target is not None else self.target features = features if features is not None else self.features - return self._get_cache_file_names( - cache_pattern, grid_shape, time_index, target, features) + return self._get_cache_file_names(cache_pattern, grid_shape, + time_index, target, features) def unnormalize(self, means, stds): """Remove normalization from stored means and stds""" @@ -826,8 +790,11 @@ def normalize(self, means, stds): array of means for all features with same ordering as data features """ max_workers = self.norm_workers - self._normalize( - self.data, self.val_data, means, stds, max_workers=max_workers) + self._normalize(self.data, + self.val_data, + means, + stds, + max_workers=max_workers) def get_next(self): """Get data for observation using random observation index. 
Loops @@ -902,19 +869,17 @@ def cache_data(self, cache_file_paths): cache_file_paths : str | None Path to file for saving feature data """ - self._cache_data( - self.data, self.features, cache_file_paths, self.overwrite_cache - ) + self._cache_data(self.data, self.features, cache_file_paths, + self.overwrite_cache) @property def requested_shape(self): """Get requested shape for cached data""" shape = get_raster_shape(self.raster_index) - requested_shape = ( - shape[0] // self.hr_spatial_coarsen, - shape[1] // self.hr_spatial_coarsen, - len(self.raw_time_index[self.temporal_slice]), - len(self.features)) + requested_shape = (shape[0] // self.hr_spatial_coarsen, + shape[1] // self.hr_spatial_coarsen, + len(self.raw_time_index[self.temporal_slice]), + len(self.features)) return requested_shape def load_cached_data(self, with_split=True): @@ -929,16 +894,14 @@ def load_cached_data(self, with_split=True): logger.info('Called load_cached_data() but self.data is not None') elif self.data is None: - msg = ( - 'Found {} cache files but need {} for features {}! ' - 'These are the cache files that were found: {}'.format( - len(self.cache_files), - len(self.features), - self.features, - self.cache_files)) + msg = ('Found {} cache files but need {} for features {}! ' + 'These are the cache files that were found: {}'.format( + len(self.cache_files), len(self.features), + self.features, self.cache_files)) assert len(self.cache_files) == len(self.features), msg - self.data = np.full(shape=self.requested_shape, fill_value=np.nan, + self.data = np.full(shape=self.requested_shape, + fill_value=np.nan, dtype=np.float32) logger.info(f'Loading cached data from: {self.cache_files}') @@ -952,14 +915,13 @@ def load_cached_data(self, with_split=True): nan_perc = 100 * np.isnan(self.data).sum() / self.data.size if nan_perc > 0: - msg = 'Data has {:.2f}% NaN values!'.format(nan_perc) + msg = 'Data has {:.3f}% NaN values!'.format(nan_perc) logger.warning(msg) warnings.warn(msg) - logger.debug( - 'Splitting data into training / validation sets ' - f'({1 - self.val_split}, {self.val_split}) ' - f'for {self.input_file_info}') + logger.debug('Splitting data into training / validation sets ' + f'({1 - self.val_split}, {self.val_split}) ' + f'for {self.input_file_info}') if with_split: self.data, self.val_data = self.split_data( @@ -984,8 +946,8 @@ def run_all_data_init(self): shifted_time_chunks = [slice(None)] else: n_steps = len(self.raw_time_index[self.temporal_slice]) - shifted_time_chunks = get_chunk_slices( - n_steps, self.time_chunk_size) + shifted_time_chunks = get_chunk_slices(n_steps, + self.time_chunk_size) self.run_data_extraction() self.run_data_compute() @@ -1002,11 +964,9 @@ def run_all_data_init(self): if self.hr_spatial_coarsen > 1: logger.debug('Applying hr spatial coarsening to data array') - self.data = spatial_coarsening( - self.data, s_enhance=self.hr_spatial_coarsen, obs_axis=False) - self.lat_lon = spatial_coarsening( - self.lat_lon, s_enhance=self.hr_spatial_coarsen, - obs_axis=False) + self.data = spatial_coarsening(self.data, + s_enhance=self.hr_spatial_coarsen, + obs_axis=False) if self.load_cached: for f in self.cached_features: f_index = self.features.index(f) @@ -1023,9 +983,8 @@ def run_data_extraction(self): un-manipulated datasets. 
""" if self.extract_features: - logger.info( - f'Starting extraction of {self.extract_features} ' - f'using {len(self.time_chunks)} time_chunks.') + logger.info(f'Starting extraction of {self.extract_features} ' + f'using {len(self.time_chunks)} time_chunks.') if self.extract_workers == 1: self._raw_data = self.serial_extract(self.file_paths, self.raster_index, @@ -1041,10 +1000,8 @@ def run_data_extraction(self): self.extract_workers, **self.res_kwargs) - logger.info( - f'Finished extracting {self.extract_features} for ' - f'{self.input_file_info}' - ) + logger.info(f'Finished extracting {self.extract_features} for ' + f'{self.input_file_info}') def run_data_compute(self): """Run the data computation / derivation from raw features to desired @@ -1054,27 +1011,20 @@ def run_data_compute(self): logger.info(f'Starting computation of {self.derive_features}') if self.compute_workers == 1: - self._raw_data = self.serial_compute(self._raw_data, - self.file_paths, - self.raster_index, - self.time_chunks, - self.derive_features, - self.noncached_features, - self.handle_features) + self._raw_data = self.serial_compute( + self._raw_data, self.file_paths, self.raster_index, + self.time_chunks, self.derive_features, + self.noncached_features, self.handle_features) elif self.compute_workers != 1: - self._raw_data = self.parallel_compute(self._raw_data, - self.file_paths, - self.raster_index, - self.time_chunks, - self.derive_features, - self.noncached_features, - self.handle_features, - self.compute_workers) + self._raw_data = self.parallel_compute( + self._raw_data, self.file_paths, self.raster_index, + self.time_chunks, self.derive_features, + self.noncached_features, self.handle_features, + self.compute_workers) - logger.info( - f'Finished computing {self.derive_features} for ' - f'{self.input_file_info}') + logger.info(f'Finished computing {self.derive_features} for ' + f'{self.input_file_info}') def data_fill(self, t, t_slice, f_index, f): """Place single extracted / computed chunk in final data array @@ -1110,9 +1060,8 @@ def serial_data_fill(self, shifted_time_chunks): self.data_fill(t, ts, f_index, f) interval = int(np.ceil(len(shifted_time_chunks) / 10)) if t % interval == 0: - logger.info( - f'Added {t + 1} of {len(shifted_time_chunks)} ' - 'chunks to final data array') + logger.info(f'Added {t + 1} of {len(shifted_time_chunks)} ' + 'chunks to final data array') self._raw_data.pop(t) def parallel_data_fill(self, shifted_time_chunks, max_workers=None): @@ -1128,11 +1077,10 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None): max available workers will be used. 
If 1 cached data will be loaded in serial """ - self.data = np.zeros((self.grid_shape[0], - self.grid_shape[1], - self.n_tsteps, - len(self.features)), - dtype=np.float32) + self.data = np.zeros( + (self.grid_shape[0], self.grid_shape[1], self.n_tsteps, + len(self.features)), + dtype=np.float32) if max_workers == 1: self.serial_data_fill(shifted_time_chunks) @@ -1146,25 +1094,22 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None): future = exe.submit(self.data_fill, t, ts, f_index, f) futures[future] = {'t': t, 'fidx': f_index} - logger.info( - f'Started adding {len(futures)} chunks ' - f'to data array in {dt.now() - now}.') + logger.info(f'Started adding {len(futures)} chunks ' + f'to data array in {dt.now() - now}.') interval = int(np.ceil(len(futures) / 10)) for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ( - f'Error adding ({futures[future]["t"]}, ' - f'{futures[future]["fidx"]}) chunk to ' - 'final data array.') + msg = (f'Error adding ({futures[future]["t"]}, ' + f'{futures[future]["fidx"]}) chunk to ' + 'final data array.') logger.exception(msg) raise RuntimeError(msg) from e if i % interval == 0: - logger.debug( - f'Added {i+1} out of {len(futures)} ' - 'chunks to final data array') + logger.debug(f'Added {i+1} out of {len(futures)} ' + 'chunks to final data array') logger.info('Finished building data array') @abstractmethod @@ -1221,18 +1166,16 @@ def lin_bc(self, bc_files, threshold=0.1): scalar = scalar[..., idm] adder = adder[..., idm] else: - msg = ( - 'Can only accept bias correction factors ' - 'with last dim equal to 1 or 12 but ' - 'received bias correction factors with ' - 'shape {}'.format(scalar.shape)) + msg = ('Can only accept bias correction factors ' + 'with last dim equal to 1 or 12 but ' + 'received bias correction factors with ' + 'shape {}'.format(scalar.shape)) logger.error(msg) raise RuntimeError(msg) - logger.info( - 'Bias correcting "{}" with linear ' - 'correction from "{}"'.format( - feature, os.path.basename(fp))) + logger.info('Bias correcting "{}" with linear ' + 'correction from "{}"'.format( + feature, os.path.basename(fp))) self.data[..., idf] *= scalar self.data[..., idf] += adder completed.append(feature) @@ -1242,7 +1185,8 @@ def lin_bc(self, bc_files, threshold=0.1): class DataHandlerDC(DataHandler): """Data-centric data handler""" - def get_observation_index(self, temporal_weights=None, + def get_observation_index(self, + temporal_weights=None, spatial_weights=None): """Randomly gets weighted spatial sample and time sample @@ -1262,20 +1206,23 @@ def get_observation_index(self, temporal_weights=None, Used to get single observation like self.data[observation_index] """ if spatial_weights is not None: - spatial_slice = weighted_box_sampler( - self.data, self.sample_shape[:2], weights=spatial_weights) + spatial_slice = weighted_box_sampler(self.data, + self.sample_shape[:2], + weights=spatial_weights) else: - spatial_slice = uniform_box_sampler( - self.data, self.sample_shape[:2]) + spatial_slice = uniform_box_sampler(self.data, + self.sample_shape[:2]) if temporal_weights is not None: - temporal_slice = weighted_time_sampler( - self.data, self.sample_shape[2], weights=temporal_weights) + temporal_slice = weighted_time_sampler(self.data, + self.sample_shape[2], + weights=temporal_weights) else: - temporal_slice = uniform_time_sampler( - self.data, self.sample_shape[2]) + temporal_slice = uniform_time_sampler(self.data, + self.sample_shape[2]) return tuple( - [*spatial_slice, 
temporal_slice, np.arange(len(self.features))]) + [*spatial_slice, temporal_slice, + np.arange(len(self.features))]) def get_next(self, temporal_weights=None, spatial_weights=None): """Get data for observation using weighted random observation index. @@ -1331,9 +1278,8 @@ def get_file_times(cls, file_paths, **kwargs): elif hasattr(handle, 'times'): time_index = np_to_pd_times(handle.times.values) else: - msg = ( - f'Could not get time_index for {file_paths}. ' - 'Assuming time independence.') + msg = (f'Could not get time_index for {file_paths}. ' + 'Assuming time independence.') time_index = None logger.warning(msg) warnings.warn(msg) @@ -1361,9 +1307,8 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): time_index : pd.Datetimeindex List of times as a Datetimeindex """ - max_workers = (len(file_paths) - if max_workers is None - else np.min((max_workers, len(file_paths)))) + max_workers = (len(file_paths) if max_workers is None else np.min( + (max_workers, len(file_paths)))) if max_workers == 1: return cls.get_file_times(file_paths, **kwargs) ti = {} @@ -1374,9 +1319,8 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): future = exe.submit(cls.get_file_times, [f], **kwargs) futures[future] = {'idx': i, 'file': f} - logger.info( - f'Started building time index from {len(file_paths)} ' - f'files in {dt.now() - now}.') + logger.info(f'Started building time index from {len(file_paths)} ' + f'files in {dt.now() - now}.') for i, future in enumerate(as_completed(futures)): try: @@ -1401,21 +1345,23 @@ def feature_registry(cls): dict Method registry """ - registry = {'BVF2_(.*)m': BVFreqSquaredNC, - 'BVF_MO_(.*)m': BVFreqMon, - 'RMOL': InverseMonNC, - 'U_(.*)': UWind, - 'V_(.*)': VWind, - 'Windspeed_(.*)m': WindspeedNC, - 'Winddirection_(.*)m': WinddirectionNC, - 'lat_lon': LatLonNC, - 'Shear_(.*)m': Shear, - 'REWS_(.*)m': Rews, - 'Temperature_(.*)m': TempNC, - 'Pressure_(.*)m': PressureNC, - 'PotentialTemp_(.*)m': PotentialTempNC, - 'PT_(.*)m': PotentialTempNC, - 'topography': ['HGT', 'orog']} + registry = { + 'BVF2_(.*)m': BVFreqSquaredNC, + 'BVF_MO_(.*)m': BVFreqMon, + 'RMOL': InverseMonNC, + 'U_(.*)': UWind, + 'V_(.*)': VWind, + 'Windspeed_(.*)m': WindspeedNC, + 'Winddirection_(.*)m': WinddirectionNC, + 'lat_lon': LatLonNC, + 'Shear_(.*)m': Shear, + 'REWS_(.*)m': Rews, + 'Temperature_(.*)m': TempNC, + 'Pressure_(.*)m': PressureNC, + 'PotentialTemp_(.*)m': PotentialTempNC, + 'PT_(.*)m': PotentialTempNC, + 'topography': ['HGT', 'orog'] + } return registry @classmethod @@ -1453,10 +1399,8 @@ def extract_feature(cls, Data array for extracted feature (spatial_1, spatial_2, temporal) """ - logger.debug( - f'Extracting {feature} with time_slice={time_slice}, ' - f'raster_index={raster_index}, kwargs={kwargs}.' 
- ) + logger.debug(f'Extracting {feature} with time_slice={time_slice}, ' + f'raster_index={raster_index}, kwargs={kwargs}.') handle = cls.source_handler(file_paths, **kwargs) f_info = Feature(feature, handle) interp_height = f_info.height @@ -1465,25 +1409,19 @@ def extract_feature(cls, if feature in handle or feature.lower() in handle: feat_key = feature if feature in handle else feature.lower() - fdata = cls.direct_extract( - handle, feat_key, raster_index, time_slice) + fdata = cls.direct_extract(handle, feat_key, raster_index, + time_slice) elif basename in handle or basename.lower() in handle: feat_key = basename if basename in handle else basename.lower() if interp_height is not None: fdata = Interpolator.interp_var_to_height( - handle, - feat_key, - raster_index, - np.float32(interp_height), + handle, feat_key, raster_index, np.float32(interp_height), time_slice) elif interp_pressure is not None: fdata = Interpolator.interp_var_to_pressure( - handle, - feat_key, - raster_index, - np.float32(interp_pressure), - time_slice) + handle, feat_key, raster_index, + np.float32(interp_pressure), time_slice) else: msg = f'{feature} cannot be extracted from source data.' @@ -1568,12 +1506,12 @@ def get_closest_lat_lon(lat_lon, target): col index for closest lat/lon to target lat/lon """ # shape of ll2 is (n, 2) where axis=1 is (lat, lon) - ll2 = np.vstack( - (lat_lon[..., 0].flatten(), lat_lon[..., 1].flatten())).T + ll2 = np.vstack((lat_lon[..., 0].flatten(), lat_lon[..., + 1].flatten())).T tree = KDTree(ll2) _, i = tree.query(np.array(target)) - row, col = np.where( - (lat_lon[..., 0] == ll2[i, 0]) & (lat_lon[..., 1] == ll2[i, 1])) + row, col = np.where((lat_lon[..., 0] == ll2[i, 0]) + & (lat_lon[..., 1] == ll2[i, 1])) row = row[0] col = col[0] return row, col @@ -1596,9 +1534,9 @@ def compute_raster_index(cls, file_paths, target, grid_shape): list List of slices corresponding to extracted data region """ - lat_lon = cls.get_lat_lon( - file_paths[:1], [slice(None), slice(None)], invert_lat=False - ) + lat_lon = cls.get_lat_lon(file_paths[:1], + [slice(None), slice(None)], + invert_lat=False) cls._check_grid_extent(target, grid_shape, lat_lon) row, col = cls.get_closest_lat_lon(lat_lon, target) @@ -1616,8 +1554,10 @@ def compute_raster_index(cls, file_paths, target, grid_shape): else: row_end = row + grid_shape[0] row_start = row - raster_index = [slice(row_start, row_end), - slice(col, col + grid_shape[1])] + raster_index = [ + slice(row_start, row_end), + slice(col, col + grid_shape[1]) + ] cls._validate_raster_shape(target, grid_shape, lat_lon, raster_index) return raster_index @@ -1641,15 +1581,12 @@ def _check_grid_extent(cls, target, grid_shape, lat_lon): min_lon = np.min(lat_lon[..., 1]) max_lat = np.max(lat_lon[..., 0]) max_lon = np.max(lat_lon[..., 1]) - logger.debug( - 'Calculating raster index from WRF file ' - f'for shape {grid_shape} and target {target}') - logger.debug( - f'lat/lon (min, max): {min_lat}/{min_lon}, ' - f'{max_lat}/{max_lon}') - msg = ( - f'target {target} out of bounds with min lat/lon ' - f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}') + logger.debug('Calculating raster index from WRF file ' + f'for shape {grid_shape} and target {target}') + logger.debug(f'lat/lon (min, max): {min_lat}/{min_lon}, ' + f'{max_lat}/{max_lon}') + msg = (f'target {target} out of bounds with min lat/lon ' + f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}') assert (min_lat <= target[0] <= max_lat and min_lon <= target[1] <= max_lon), msg @@ -1673,16 +1610,14 @@ 
def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): """ check = (raster_index[0].stop > lat_lon.shape[0] or raster_index[1].stop > lat_lon.shape[1] - or raster_index[0].start < 0 - or raster_index[1].start < 0) + or raster_index[0].start < 0 or raster_index[1].start < 0) if check: - msg = ( - f'Invalid target {target}, shape {grid_shape}, and raster ' - f'{raster_index} for data domain of size ' - f'{lat_lon.shape[:-1]} with lower left corner ' - f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' - f' and upper right corner ({np.max(lat_lon[..., 0])}, ' - f'{np.max(lat_lon[..., 1])}).') + msg = (f'Invalid target {target}, shape {grid_shape}, and raster ' + f'{raster_index} for data domain of size ' + f'{lat_lon.shape[:-1]} with lower left corner ' + f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' + f' and upper right corner ({np.max(lat_lon[..., 0])}, ' + f'{np.max(lat_lon[..., 1])}).') raise ValueError(msg) def get_raster_index(self): @@ -1695,27 +1630,23 @@ def get_raster_index(self): raster_index : np.ndarray 2D array of grid indices """ - self.raster_file = ( - self.raster_file - if self.raster_file is None - else self.raster_file.replace('.txt', '.npy')) + self.raster_file = (self.raster_file if self.raster_file is None else + self.raster_file.replace('.txt', '.npy')) if self.raster_file is not None and os.path.exists(self.raster_file): - logger.debug( - f'Loading raster index: {self.raster_file} ' - f'for {self.input_file_info}') + logger.debug(f'Loading raster index: {self.raster_file} ' + f'for {self.input_file_info}') raster_index = np.load(self.raster_file, allow_pickle=True) raster_index = list(raster_index) else: check = self.grid_shape is not None and self.target is not None - msg = ( - 'Must provide raster file or shape + target to get ' - 'raster index') + msg = ('Must provide raster file or shape + target to get ' + 'raster index') assert check, msg - raster_index = self.compute_raster_index( - self.file_paths, self.target, self.grid_shape) - logger.debug( - 'Found raster index with row, col slices: {}'.format( - raster_index)) + raster_index = self.compute_raster_index(self.file_paths, + self.target, + self.grid_shape) + logger.debug('Found raster index with row, col slices: {}'.format( + raster_index)) if self.raster_file is not None: basedir = os.path.dirname(self.raster_file) diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index 70bbd4af8..522edd70e 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -5,10 +5,9 @@ import numpy as np import pandas as pd -from sup3r.preprocessing.data_handling.mixin import ( - CacheHandlingMixIn, - TrainingPrepMixIn, -) +from sup3r.preprocessing.data_handling.mixin import (CacheHandlingMixIn, + TrainingPrepMixIn, + ) from sup3r.utilities.regridder import Regridder from sup3r.utilities.utilities import spatial_coarsening @@ -178,15 +177,16 @@ def _shape_check(self): logger.info("Loading high resolution cache.") self.hr_dh.load_cached_data(with_split=False) - msg = ('hr_handler.shape is not divisible by s_enhance. Using ' - f'shape = {self.hr_required_shape} instead.') + msg = (f'hr_handler.shape {self.hr_dh.shape[:-1]} is not divisible ' + f'by s_enhance. 
Using shape = {self.hr_required_shape} ' + 'instead.') if self.hr_dh.shape[:-1] != self.hr_required_shape: logger.warning(msg) warn(msg) - self.hr_data = self.hr_dh.data[:self.hr_required_shape[0], - :self.hr_required_shape[1], - :self.hr_required_shape[2]] + self.hr_data = self.hr_dh.data[:self.hr_required_shape[0], :self. + hr_required_shape[1], :self. + hr_required_shape[2]] self.hr_time_index = self.hr_dh.time_index[:self.hr_required_shape[2]] self.lr_time_index = self.lr_dh.time_index[:self.lr_required_shape[2]] @@ -255,7 +255,7 @@ def feature_mem(self): int Number of bytes for a single feature array """ - feature_mem = self.grid_mem * len(self.lr_time_index) + feature_mem = self.grid_mem * self.lr_data.shape[-2] return feature_mem @property @@ -294,7 +294,7 @@ def lr_input_data(self): @property def shape(self): """Get low_res shape""" - return self.lr_dh.shape + return self.lr_data.shape @property def lr_required_shape(self): @@ -338,6 +338,7 @@ def lr_lat_lon(self, lat_lon): def hr_lat_lon(self): """Get high_res lat lon array""" if self._hr_lat_lon is None: + self._hr_lat_lon = self.hr_dh.lat_lon[:self.hr_required_shape[0], : self.hr_required_shape[1]] return self._hr_lat_lon diff --git a/sup3r/preprocessing/dual_batch_handling.py b/sup3r/preprocessing/dual_batch_handling.py index a6a1d34d4..ac1901d9a 100644 --- a/sup3r/preprocessing/dual_batch_handling.py +++ b/sup3r/preprocessing/dual_batch_handling.py @@ -3,11 +3,9 @@ import numpy as np -from sup3r.preprocessing.batch_handling import ( - Batch, - BatchHandler, - ValidationData, -) +from sup3r.preprocessing.batch_handling import (Batch, BatchHandler, + ValidationData, + ) from sup3r.utilities.utilities import uniform_box_sampler, uniform_time_sampler logger = logging.getLogger(__name__) @@ -46,22 +44,18 @@ def _get_val_indices(self): hr_index = [] for s in lr_index[:2]: hr_index.append( - slice( - s.start * self.s_enhance, - s.stop * self.s_enhance, - )) + slice(s.start * self.s_enhance, + s.stop * self.s_enhance)) for s in lr_index[2:-1]: hr_index.append( - slice( - s.start * self.t_enhance, - s.stop * self.t_enhance, - )) + slice(s.start * self.t_enhance, + s.stop * self.t_enhance)) hr_index.append(lr_index[-1]) hr_index = tuple(hr_index) val_indices.append({ 'handler_index': i, 'hr_index': hr_index, - 'lr_index': lr_index, + 'lr_index': lr_index }) return val_indices @@ -79,12 +73,9 @@ def shape(self): time_steps = 0 for h in self.data_handlers: time_steps += h.hr_val_data.shape[2] - return ( - self.data_handlers[0].hr_val_data.shape[0], - self.data_handlers[0].hr_val_data.shape[1], - time_steps, - self.data_handlers[0].hr_val_data.shape[3], - ) + return (self.data_handlers[0].hr_val_data.shape[0], + self.data_handlers[0].hr_val_data.shape[1], time_steps, + self.data_handlers[0].hr_val_data.shape[3]) @property def hr_sample_shape(self): @@ -113,23 +104,13 @@ def __next__(self): n_obs = self._remaining_observations high_res = np.zeros( - ( - n_obs, - self.hr_sample_shape[0], - self.hr_sample_shape[1], - self.hr_sample_shape[2], - self.data_handlers[0].shape[-1], - ), + (n_obs, self.hr_sample_shape[0], self.hr_sample_shape[1], + self.hr_sample_shape[2], self.data_handlers[0].shape[-1]), dtype=np.float32, ) low_res = np.zeros( - ( - n_obs, - self.lr_sample_shape[0], - self.lr_sample_shape[1], - self.lr_sample_shape[2], - self.data_handlers[0].shape[-1], - ), + (n_obs, self.lr_sample_shape[0], self.lr_sample_shape[1], + self.lr_sample_shape[2], self.data_handlers[0].shape[-1]), dtype=np.float32, ) for i in range(high_res.shape[0]): 
@@ -190,26 +171,14 @@ def __next__(self): handler_index = np.random.randint(0, len(self.data_handlers)) self.current_handler_index = handler_index handler = self.data_handlers[handler_index] - high_res = np.zeros( - ( - self.batch_size, - self.hr_sample_shape[0], - self.hr_sample_shape[1], - self.hr_sample_shape[2], - self.shape[-1], - ), - dtype=np.float32, - ) - low_res = np.zeros( - ( - self.batch_size, - self.lr_sample_shape[0], - self.lr_sample_shape[1], - self.lr_sample_shape[2], - self.shape[-1], - ), - dtype=np.float32, - ) + high_res = np.zeros((self.batch_size, self.hr_sample_shape[0], + self.hr_sample_shape[1], + self.hr_sample_shape[2], self.shape[-1]), + dtype=np.float32) + low_res = np.zeros((self.batch_size, self.lr_sample_shape[0], + self.lr_sample_shape[1], + self.lr_sample_shape[2], self.shape[-1]), + dtype=np.float32) for i in range(self.batch_size): high_res[i, ...], low_res[i, ...] = handler.get_next() @@ -250,24 +219,12 @@ def __next__(self): handler_index = np.random.randint(0, len(self.data_handlers)) self.current_handler_index = handler_index handler = self.data_handlers[handler_index] - high_res = np.zeros( - ( - self.batch_size, - self.hr_sample_shape[0], - self.hr_sample_shape[1], - self.shape[-1], - ), - dtype=np.float32, - ) - low_res = np.zeros( - ( - self.batch_size, - self.lr_sample_shape[0], - self.lr_sample_shape[1], - self.shape[-1], - ), - dtype=np.float32, - ) + high_res = np.zeros((self.batch_size, self.hr_sample_shape[0], + self.hr_sample_shape[1], self.shape[-1]), + dtype=np.float32) + low_res = np.zeros((self.batch_size, self.lr_sample_shape[0], + self.lr_sample_shape[1], self.shape[-1]), + dtype=np.float32) for i in range(self.batch_size): hr, lr = handler.get_next() diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index 97bad45ea..003cf51d4 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -29,16 +29,14 @@ class Regridder: MIN_DISTANCE = 1e-12 - def __init__( - self, - source_meta, - target_meta, - cache_pattern=None, - leaf_size=4, - k_neighbors=4, - n_chunks=100, - max_workers=None, - ): + def __init__(self, + source_meta, + target_meta, + cache_pattern=None, + leaf_size=4, + k_neighbors=4, + n_chunks=100, + max_workers=None): """Get weights and indices used to map from source grid to target grid Parameters @@ -86,16 +84,14 @@ def __init__( self.cache_all_queries() @classmethod - def run( - cls, - source_meta, - target_meta, - cache_pattern=None, - leaf_size=4, - k_neighbors=4, - n_chunks=100, - max_workers=None, - ): + def run(cls, + source_meta, + target_meta, + cache_pattern=None, + leaf_size=4, + k_neighbors=4, + n_chunks=100, + max_workers=None): """Query tree for every point in target_meta to get full set of indices and distances for the neighboring points in the source_meta. @@ -122,15 +118,13 @@ def run( to building full set of indices and distances for each target_meta coordinate. 
""" - regridder = cls( - source_meta=source_meta, - target_meta=target_meta, - cache_pattern=cache_pattern, - leaf_size=leaf_size, - k_neighbors=k_neighbors, - n_chunks=n_chunks, - max_workers=max_workers, - ) + regridder = cls(source_meta=source_meta, + target_meta=target_meta, + cache_pattern=cache_pattern, + leaf_size=leaf_size, + k_neighbors=k_neighbors, + n_chunks=n_chunks, + max_workers=max_workers) if not regridder.cache_exists: regridder.get_all_queries(max_workers) regridder.cache_all_queries() @@ -142,10 +136,8 @@ def weights(self): dists = np.array(self.distances, dtype=np.float32) mask = dists < self.MIN_DISTANCE if mask.sum() > 0: - logger.info( - f'{np.sum(mask)} of {np.product(mask.shape)} ' - 'distances are zero.' - ) + logger.info(f'{np.sum(mask)} of {np.product(mask.shape)} ' + 'distances are zero.') dists[mask] = self.MIN_DISTANCE weights = 1 / dists self._weights = weights / np.sum(weights, axis=-1)[:, None] @@ -154,12 +146,10 @@ def weights(self): @property def cache_exists(self): """Check if cache exists before building tree.""" - cache_exists_check = ( - self.index_file is not None - and os.path.exists(self.index_file) - and self.distance_file is not None - and os.path.exists(self.distance_file) - ) + cache_exists_check = (self.index_file is not None + and os.path.exists(self.index_file) + and self.distance_file is not None + and os.path.exists(self.distance_file)) return cache_exists_check def build_tree(self): @@ -200,13 +190,10 @@ def _parallel_queries(self, max_workers=None): future = exe.submit(self.save_query, s_slice=s_slice) futures[future] = i mem = psutil.virtual_memory() - msg = ( - 'Query futures submitted: {} out of {}. Current ' - 'memory usage is {:.3f} GB out of {:.3f} GB ' - 'total.'.format( - i + 1, len(slices), mem.used / 1e9, mem.total / 1e9 - ) - ) + msg = ('Query futures submitted: {} out of {}. Current ' + 'memory usage is {:.3f} GB out of {:.3f} GB ' + 'total.'.format(i + 1, len(slices), mem.used / 1e9, + mem.total / 1e9)) logger.info(msg) logger.info(f'Submitted all query futures in {dt.now() - now}.') @@ -214,21 +201,17 @@ def _parallel_queries(self, max_workers=None): for i, future in enumerate(as_completed(futures)): idx = futures[future] mem = psutil.virtual_memory() - msg = ( - 'Query futures completed: {} out of ' - '{}. Current memory usage is {:.3f} ' - 'GB out of {:.3f} GB total.'.format( - i + 1, len(futures), mem.used / 1e9, mem.total / 1e9 - ) - ) + msg = ('Query futures completed: {} out of ' + '{}. 
Current memory usage is {:.3f} ' + 'GB out of {:.3f} GB total.'.format( + i + 1, len(futures), mem.used / 1e9, + mem.total / 1e9)) logger.info(msg) try: future.result() except Exception as e: - msg = ( - 'Failed to query coordinate chunk with ' - 'index={index}'.format(index=idx) - ) + msg = ('Failed to query coordinate chunk with ' + 'index={index}'.format(index=idx)) logger.exception(msg) raise RuntimeError(msg) from e @@ -244,9 +227,8 @@ def load_cache(self): self.indices = pickle.load(f) with open(self.distance_file, 'rb') as f: self.distances = pickle.load(f) - logger.info( - f'Loaded cache files: {self.index_file}, ' f'{self.distance_file}' - ) + logger.info(f'Loaded cache files: {self.index_file}, ' + f'{self.distance_file}') def cache_all_queries(self): """Cache indices and distances from ball tree query""" @@ -255,10 +237,8 @@ def cache_all_queries(self): pickle.dump(self.indices, f, protocol=4) with open(self.distance_file, 'wb') as f: pickle.dump(self.distances, f, protocol=4) - logger.info( - f'Saved cache files: {self.index_file}, ' - f'{self.distance_file}' - ) + logger.info(f'Saved cache files: {self.index_file}, ' + f'{self.distance_file}') @property def index_file(self): @@ -313,9 +293,8 @@ def query_tree(self, s_slice): Array of indices for neighboring points for each point selected by s_slice. (n_ponts, k_neighbors) """ - return self.tree.query( - self.get_spatial_chunk(s_slice), k=self.k_neighbors - ) + return self.tree.query(self.get_spatial_chunk(s_slice), + k=self.k_neighbors) @classmethod def interpolate(cls, distance_chunk, values): @@ -342,10 +321,8 @@ def interpolate(cls, distance_chunk, values): dists = np.array(distance_chunk, dtype=np.float32) mask = dists < cls.MIN_DISTANCE if mask.sum() > 0: - logger.info( - f'{np.sum(mask)} of {np.product(mask.shape)} ' - 'distances are zero.' 
- ) + logger.info(f'{np.sum(mask)} of {np.product(mask.shape)} ' + 'distances are zero.') dists[mask] = cls.MIN_DISTANCE weights = 1 / dists norm = np.sum(weights, axis=-1) @@ -367,13 +344,11 @@ def __call__(self, data): Flattened regridded spatiotemporal data (spatial, temporal) """ - vals = np.concatenate( - [ - data[:, :, i].flatten()[self.indices][np.newaxis] - for i in range(data.shape[-1]) - ], - axis=0, - ) + vals = [ + data[:, :, i].flatten()[self.indices][np.newaxis] + for i in range(data.shape[-1]) + ] + vals = np.concatenate(vals, axis=0) out = np.einsum('ijk,jk->ij', vals, self.weights) return out.T @@ -406,11 +381,9 @@ def get_source_values(cls, index_chunk, feature, source_files): (temporal, n_points, k_neighbors) """ with MultiFileResource(source_files) as res: - shape = ( - len(res.time_index), - len(index_chunk), - len(index_chunk[0]), - ) + shape = (len(res.time_index), len(index_chunk), + len(index_chunk[0]), + ) tmp = np.array(index_chunk).flatten() out = res[feature, :, tmp] out = out.reshape(shape) @@ -441,12 +414,10 @@ def get_source_uv(cls, index_chunk, height, source_files): Array of meridional wind values to use for interpolation with shape (temporal, n_points, k_neighbors) """ - ws = cls.get_source_values( - index_chunk, f'windspeed_{height}m', source_files - ) - wd = cls.get_source_values( - index_chunk, f'winddirection_{height}m', source_files - ) + ws = cls.get_source_values(index_chunk, f'windspeed_{height}m', + source_files) + wd = cls.get_source_values(index_chunk, f'winddirection_{height}m', + source_files) u = ws * np.sin(np.radians(wd)) v = ws * np.cos(np.radians(wd)) @@ -480,9 +451,8 @@ def invert_uv(cls, u, v): return ws, wd @classmethod - def regrid_coordinates( - cls, index_chunk, distance_chunk, height, source_files - ): + def regrid_coordinates(cls, index_chunk, distance_chunk, height, + source_files): """Regrid wind fields at given height for the requested coordinate index @@ -525,20 +495,19 @@ class RegridOutput(OutputMixIn, DistributedProcess): a new target grid. The interpolated data is then written to new files, with one file for each field (e.g. 
windspeed_100m).""" - def __init__( - self, - source_files, - out_pattern, - target_meta, - heights, - cache_pattern=None, - leaf_size=4, - k_neighbors=4, - incremental=False, - n_chunks=100, - max_nodes=1, - worker_kwargs=None, - ): + def __init__(self, + source_files, + out_pattern, + target_meta, + heights, + cache_pattern=None, + leaf_size=4, + k_neighbors=4, + incremental=False, + n_chunks=100, + max_nodes=1, + worker_kwargs=None, + ): """ Parameters ---------- @@ -575,17 +544,13 @@ def __init__( worker_kwargs = worker_kwargs or {} self.regrid_workers = worker_kwargs.get('regrid_workers', None) self.query_workers = worker_kwargs.get('query_workers', None) - self.source_files = ( - source_files - if isinstance(source_files, list) - else glob(source_files) - ) + self.source_files = (source_files if isinstance(source_files, list) + else glob(source_files)) self.target_meta_path = target_meta self.target_meta = pd.read_csv(self.target_meta_path) self.target_meta['gid'] = np.arange(len(self.target_meta)) self.target_meta = self.target_meta.sort_values( - ['latitude', 'longitude'], ascending=[False, True] - ) + ['latitude', 'longitude'], ascending=[False, True]) self.heights = heights self.incremental = incremental self.out_pattern = out_pattern @@ -596,32 +561,26 @@ def __init__( self.source_meta = res.meta self.global_attrs = res.global_attrs - self.regridder = WindRegridder( - self.source_meta, - self.target_meta, - leaf_size=leaf_size, - k_neighbors=k_neighbors, - cache_pattern=cache_pattern, - n_chunks=n_chunks, - max_workers=self.query_workers, - ) - DistributedProcess.__init__( - self, - max_nodes=max_nodes, - n_chunks=n_chunks, - max_chunks=len(self.regridder.indices), - incremental=incremental, - ) - - logger.info( - 'Initializing RegridOutput with ' - f'source_files={self.source_files}, ' - f'out_pattern={self.out_pattern}, ' - f'heights={self.heights}, ' - f'target_meta={target_meta}, ' - f'k_neighbors={k_neighbors}, and ' - f'n_chunks={n_chunks}.' - ) + self.regridder = WindRegridder(self.source_meta, + self.target_meta, + leaf_size=leaf_size, + k_neighbors=k_neighbors, + cache_pattern=cache_pattern, + n_chunks=n_chunks, + max_workers=self.query_workers) + DistributedProcess.__init__(self, + max_nodes=max_nodes, + n_chunks=n_chunks, + max_chunks=len(self.regridder.indices), + incremental=incremental) + + logger.info('Initializing RegridOutput with ' + f'source_files={self.source_files}, ' + f'out_pattern={self.out_pattern}, ' + f'heights={self.heights}, ' + f'target_meta={target_meta}, ' + f'k_neighbors={k_neighbors}, and ' + f'n_chunks={n_chunks}.') logger.info(f'Max memory usage: {self.max_memory:.3f} GB.') @property @@ -688,12 +647,10 @@ def get_node_cmd(cls, config): run regridding. 
""" - import_str = ( - 'from sup3r.utilities.regridder import RegridOutput;\n' - 'from rex import init_logger;\n' - 'import time;\n' - 'from sup3r.pipeline import Status;\n' - ) + import_str = ('from sup3r.utilities.regridder import RegridOutput;\n' + 'from rex import init_logger;\n' + 'import time;\n' + 'from sup3r.pipeline import Status;\n') regrid_fun_str = get_fun_call_str(cls, config) @@ -704,14 +661,12 @@ def get_node_cmd(cls, config): if log_file is not None: log_arg_str += f', log_file="{log_file}"' - cmd = ( - f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"regrid_output = {regrid_fun_str};\n" - f"regrid_output.run({node_index});\n" - "t_elap = time.time() - t0;\n" - ) + cmd = (f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"regrid_output = {regrid_fun_str};\n" + f"regrid_output.run({node_index});\n" + "t_elap = time.time() - t0;\n") cmd = BaseCLI.add_status_cmd(config, ModuleName.REGRID, cmd) cmd += ";\'\n" @@ -731,15 +686,13 @@ def run(self, node_index): return if self.regrid_workers == 1: - self._run_serial( - source_files=self.source_files, node_index=node_index - ) + self._run_serial(source_files=self.source_files, + node_index=node_index) else: - self._run_parallel( - source_files=self.source_files, - node_index=node_index, - max_workers=self.regrid_workers, - ) + self._run_parallel(source_files=self.source_files, + node_index=node_index, + max_workers=self.regrid_workers, + ) def _run_serial(self, source_files, node_index): """Regrid data and write to output file, in serial. @@ -754,21 +707,15 @@ def _run_serial(self, source_files, node_index): """ logger.info('Regridding all coordinates in serial.') for i, chunk_index in enumerate(self.node_chunks[node_index]): - self.write_coordinates( - source_files=source_files, chunk_index=chunk_index - ) + self.write_coordinates(source_files=source_files, + chunk_index=chunk_index) mem = psutil.virtual_memory() - msg = ( - 'Coordinate chunks regridded: {} out of {}. ' - 'Current memory usage is {:.3f} GB out of {:.3f} ' - 'GB total.'.format( - i + 1, - len(self.node_chunks[node_index]), - mem.used / 1e9, - mem.total / 1e9, - ) - ) + msg = ('Coordinate chunks regridded: {} out of {}. ' + 'Current memory usage is {:.3f} GB out of {:.3f} ' + 'GB total.'.format(i + 1, len(self.node_chunks[node_index]), + mem.used / 1e9, mem.total / 1e9, + )) logger.info(msg) def _run_parallel(self, source_files, node_index, max_workers=None): @@ -789,16 +736,14 @@ def _run_parallel(self, source_files, node_index, max_workers=None): logger.info('Regridding all coordinates in parallel.') with ThreadPoolExecutor(max_workers=max_workers) as exe: for i, chunk_index in enumerate(self.node_chunks[node_index]): - future = exe.submit( - self.write_coordinates, - source_files=source_files, - chunk_index=chunk_index, - ) + future = exe.submit(self.write_coordinates, + source_files=source_files, + chunk_index=chunk_index, + ) futures[future] = chunk_index mem = psutil.virtual_memory() msg = 'Regrid futures submitted: {} out of {}'.format( - i + 1, len(self.node_chunks[node_index]) - ) + i + 1, len(self.node_chunks[node_index])) logger.info(msg) logger.info(f'Submitted all regrid futures in {dt.now() - now}.') @@ -806,26 +751,19 @@ def _run_parallel(self, source_files, node_index, max_workers=None): for i, future in enumerate(as_completed(futures)): idx = futures[future] mem = psutil.virtual_memory() - msg = ( - 'Regrid futures completed: {} out of {}, in {}. 
' - 'Current memory usage is {:.3f} GB out of {:.3f} GB ' - 'total.'.format( - i + 1, - len(futures), - dt.now() - now, - mem.used / 1e9, - mem.total / 1e9, - ) - ) + msg = ('Regrid futures completed: {} out of {}, in {}. ' + 'Current memory usage is {:.3f} GB out of {:.3f} GB ' + 'total.'.format(i + 1, len(futures), + dt.now() - now, mem.used / 1e9, + mem.total / 1e9, + )) logger.info(msg) try: future.result() except Exception as e: - msg = ( - 'Falied to regrid coordinate chunks with ' - 'index={index}'.format(index=idx) - ) + msg = ('Falied to regrid coordinate chunks with ' + 'index={index}'.format(index=idx)) logger.exception(msg) raise RuntimeError(msg) from e @@ -857,21 +795,18 @@ def write_coordinates(self, source_files, chunk_index): index_chunk=index_chunk, distance_chunk=distance_chunk, height=height, - source_files=source_files, - ) + source_files=source_files) features = [f'windspeed_{height}m', f'winddirection_{height}m'] for dset, data in zip(features, [ws, wd]): attrs, dtype = self.get_dset_attrs(dset) - fh.add_dataset( - tmp_file, - dset, - data, - dtype=dtype, - attrs=attrs, - chunks=attrs['chunks'], - ) + fh.add_dataset(tmp_file, + dset, + data, + dtype=dtype, + attrs=attrs, + chunks=attrs['chunks']) logger.info(f'Added {features} to {out_file}') os.replace(tmp_file, out_file) From e49879bfac6a4f2287d20f73c552e67aad57c770 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 31 Aug 2023 08:52:06 -0600 Subject: [PATCH 2/9] added multi handler and hr_spatial_coarsen to dual handler tests --- .../data_handling/test_dual_data_handling.py | 450 ++++++++---------- 1 file changed, 206 insertions(+), 244 deletions(-) diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index 02669afcc..adbbe0d93 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -9,14 +9,12 @@ from sup3r import TEST_DATA_DIR from sup3r.preprocessing.data_handling.dual_data_handling import ( - DualDataHandler, -) + DualDataHandler, ) from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5 from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC -from sup3r.preprocessing.dual_batch_handling import ( - DualBatchHandler, - SpatialDualBatchHandler, -) +from sup3r.preprocessing.dual_batch_handling import (DualBatchHandler, + SpatialDualBatchHandler, + ) from sup3r.utilities.utilities import spatial_coarsening FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') @@ -25,38 +23,41 @@ FEATURES = ['U_100m', 'V_100m'] -def test_dual_data_handler( - log=True, full_shape=(20, 20), sample_shape=(10, 10, 1), plot=True -): +def test_dual_data_handler(log=True, + full_shape=(20, 20), + sample_shape=(10, 10, 1), + plot=True): """Test basic spatial model training with only gen content loss.""" if log: init_logger('sup3r', log_level='DEBUG') # need to reduce the number of temporal examples to test faster - hr_handler = DataHandlerH5( - FP_WTK, - FEATURES, - target=TARGET_COORD, - shape=full_shape, - sample_shape=sample_shape, - temporal_slice=slice(None, None, 10), - worker_kwargs=dict(max_workers=1), - ) - lr_handler = DataHandlerNC( - FP_ERA, - FEATURES, - sample_shape=(sample_shape[0] // 2, sample_shape[1] // 2, 1), - temporal_slice=slice(None, None, 10), - worker_kwargs=dict(max_workers=1), - ) - - dual_handler = DualDataHandler( - hr_handler, lr_handler, s_enhance=2, t_enhance=1, val_split=0.1 - ) - - batch_handler = SpatialDualBatchHandler( - [dual_handler], 
batch_size=2, s_enhance=2, n_batches=10 - ) + hr_handler = DataHandlerH5(FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + lr_handler = DataHandlerNC(FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // 2, + sample_shape[1] // 2, 1), + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + + dual_handler = DualDataHandler(hr_handler, + lr_handler, + s_enhance=2, + t_enhance=1, + val_split=0.1) + + batch_handler = SpatialDualBatchHandler([dual_handler], + batch_size=2, + s_enhance=2, + n_batches=10) if plot: for i, batch in enumerate(batch_handler): @@ -66,77 +67,72 @@ def test_dual_data_handler( ax[0].imshow(np.mean(batch.high_res[..., -1], axis=0)) ax[1].set_title('Low Res') ax[1].imshow(np.mean(batch.low_res[..., -1], axis=0)) - fig.savefig( - f'./high_vs_low_{str(i).zfill(3)}.png', bbox_inches='tight' - ) + fig.savefig(f'./high_vs_low_{str(i).zfill(3)}.png', + bbox_inches='tight') -def test_regrid_caching( - log=True, full_shape=(20, 20), sample_shape=(10, 10, 1) -): +def test_regrid_caching(log=True, + full_shape=(20, 20), + sample_shape=(10, 10, 1)): """Test caching and loading of regridded data""" if log: init_logger('sup3r', log_level='DEBUG') # need to reduce the number of temporal examples to test faster with tempfile.TemporaryDirectory() as td: - hr_handler = DataHandlerH5( - FP_WTK, - FEATURES, - target=TARGET_COORD, - shape=full_shape, - sample_shape=sample_shape, - temporal_slice=slice(None, None, 10), - worker_kwargs=dict(max_workers=1), - ) - lr_handler = DataHandlerNC( - FP_ERA, - FEATURES, - sample_shape=(sample_shape[0] // 2, sample_shape[1] // 2, 1), - temporal_slice=slice(None, None, 10), - worker_kwargs=dict(max_workers=1), - ) - old_dh = DualDataHandler( - hr_handler, - lr_handler, - s_enhance=2, - t_enhance=1, - val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', - ) + hr_handler = DataHandlerH5(FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + lr_handler = DataHandlerNC(FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // 2, + sample_shape[1] // 2, 1), + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + old_dh = DualDataHandler(hr_handler, + lr_handler, + s_enhance=2, + t_enhance=1, + val_split=0.1, + regrid_cache_pattern=f'{td}/cache.pkl', + ) # Load handlers again - hr_handler = DataHandlerH5( - FP_WTK, - FEATURES, - target=TARGET_COORD, - shape=full_shape, - sample_shape=sample_shape, - temporal_slice=slice(None, None, 10), - worker_kwargs=dict(max_workers=1), - ) - lr_handler = DataHandlerNC( - FP_ERA, - FEATURES, - sample_shape=(sample_shape[0] // 2, sample_shape[1] // 2, 1), - temporal_slice=slice(None, None, 10), - worker_kwargs=dict(max_workers=1), - ) - new_dh = DualDataHandler( - hr_handler, - lr_handler, - s_enhance=2, - t_enhance=1, - val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', - ) + hr_handler = DataHandlerH5(FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + lr_handler = DataHandlerNC(FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // 2, + sample_shape[1] // 2, 1), + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + new_dh = DualDataHandler(hr_handler, + lr_handler, + 
s_enhance=2, + t_enhance=1, + val_split=0.1, + regrid_cache_pattern=f'{td}/cache.pkl', + ) assert np.array_equal(old_dh.lr_data, new_dh.lr_data) assert np.array_equal(old_dh.hr_data, new_dh.hr_data) -def test_st_dual_batch_handler( - log=False, full_shape=(20, 20), sample_shape=(10, 10, 4) -): +def test_st_dual_batch_handler(log=False, + full_shape=(20, 20), + sample_shape=(10, 10, 4)): """Test spatiotemporal dual batch handler.""" t_enhance = 2 s_enhance = 2 @@ -145,118 +141,106 @@ def test_st_dual_batch_handler( init_logger('sup3r', log_level='DEBUG') # need to reduce the number of temporal examples to test faster - hr_handler = DataHandlerH5( - FP_WTK, - FEATURES, - target=TARGET_COORD, - shape=full_shape, - sample_shape=sample_shape, - temporal_slice=slice(None, None, 10), - worker_kwargs=dict(max_workers=1), - ) - lr_handler = DataHandlerNC( - FP_ERA, - FEATURES, - sample_shape=( - sample_shape[0] // s_enhance, - sample_shape[1] // s_enhance, - sample_shape[2] // t_enhance, - ), - temporal_slice=slice(None, None, t_enhance * 10), - worker_kwargs=dict(max_workers=1), - ) - - dual_handler = DualDataHandler( - hr_handler, - lr_handler, - s_enhance=s_enhance, - t_enhance=t_enhance, - val_split=0.1, - ) - - batch_handler = DualBatchHandler( - [dual_handler], - batch_size=2, - s_enhance=s_enhance, - t_enhance=t_enhance, - n_batches=10, - ) + hr_handler = DataHandlerH5(FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1)) + lr_handler = DataHandlerNC(FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // s_enhance, + sample_shape[1] // s_enhance, + sample_shape[2] // t_enhance, + ), + temporal_slice=slice(None, None, + t_enhance * 10), + worker_kwargs=dict(max_workers=1)) + + dual_handler = DualDataHandler(hr_handler, + lr_handler, + s_enhance=s_enhance, + t_enhance=t_enhance, + val_split=0.1) + + batch_handler = DualBatchHandler([dual_handler, dual_handler], + batch_size=2, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=10) for batch in batch_handler: + + handler_index = batch_handler.current_handler_index + handler = batch_handler.data_handlers[handler_index] + for i, index in enumerate(batch_handler.current_batch_indices): hr_index = index['hr_index'] lr_index = index['lr_index'] coarse_lat_lon = spatial_coarsening( - dual_handler.hr_lat_lon[hr_index[:2]], obs_axis=False - ) - lr_lat_lon = dual_handler.lr_lat_lon[lr_index[:2]] + handler.hr_lat_lon[hr_index[:2]], obs_axis=False) + lr_lat_lon = handler.lr_lat_lon[lr_index[:2]] assert np.array_equal(coarse_lat_lon, lr_lat_lon) - coarse_ti = dual_handler.hr_time_index[hr_index[2]][::t_enhance] - lr_ti = dual_handler.lr_time_index[lr_index[2]] + coarse_ti = handler.hr_time_index[hr_index[2]][::t_enhance] + lr_ti = handler.lr_time_index[lr_index[2]] assert np.array_equal(coarse_ti.values, lr_ti.values) - assert np.array_equal( - batch.high_res[i], - dual_handler.hr_data[hr_index], - ) - assert np.array_equal( - batch.low_res[i], - dual_handler.lr_data[lr_index], - ) + assert np.array_equal(batch.high_res[i], handler.hr_data[hr_index]) + assert np.array_equal(batch.low_res[i], handler.lr_data[lr_index]) -def test_spatial_dual_batch_handler( - log=False, full_shape=(20, 20), sample_shape=(10, 10, 1), plot=True -): +def test_spatial_dual_batch_handler(log=False, + full_shape=(20, 20), + sample_shape=(10, 10, 1), + plot=True): """Test spatial dual batch handler.""" if log: init_logger('sup3r', log_level='DEBUG') # need to 
reduce the number of temporal examples to test faster - hr_handler = DataHandlerH5( - FP_WTK, - FEATURES, - target=TARGET_COORD, - shape=full_shape, - sample_shape=sample_shape, - temporal_slice=slice(None, None, 10), - worker_kwargs=dict(max_workers=1), - ) - lr_handler = DataHandlerNC( - FP_ERA, - FEATURES, - sample_shape=(sample_shape[0] // 2, sample_shape[1] // 2, 1), - temporal_slice=slice(None, None, 10), - worker_kwargs=dict(max_workers=1), - ) - - dual_handler = DualDataHandler( - hr_handler, lr_handler, s_enhance=2, t_enhance=1, val_split=0.0 - ) - - batch_handler = SpatialDualBatchHandler( - [dual_handler], batch_size=2, s_enhance=2, t_enhance=1, n_batches=10 - ) + hr_handler = DataHandlerH5(FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + hr_spatial_coarsen=2, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1)) + lr_handler = DataHandlerNC(FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // 2, + sample_shape[1] // 2, 1), + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1)) + + dual_handler = DualDataHandler(hr_handler, + lr_handler, + s_enhance=2, + t_enhance=1, + val_split=0.0, + shuffle_time=True) + + batch_handler = SpatialDualBatchHandler([dual_handler], + batch_size=2, + s_enhance=2, + t_enhance=1, + n_batches=10) for i, batch in enumerate(batch_handler): for j, index in enumerate(batch_handler.current_batch_indices): hr_index = index['hr_index'] lr_index = index['lr_index'] - assert np.array_equal( - batch.high_res[j, :, :], - dual_handler.hr_data[hr_index][..., 0, :], - ) - assert np.array_equal( - batch.low_res[j, :, :], - dual_handler.lr_data[lr_index][..., 0, :], - ) + assert np.array_equal(batch.high_res[j, :, :], + dual_handler.hr_data[hr_index][..., 0, :]) + assert np.array_equal(batch.low_res[j, :, :], + dual_handler.lr_data[lr_index][..., 0, :]) coarse_lat_lon = spatial_coarsening( - dual_handler.hr_lat_lon[hr_index[:2]], obs_axis=False - ) + dual_handler.hr_lat_lon[hr_index[:2]], obs_axis=False) lr_lat_lon = dual_handler.lr_lat_lon[lr_index[:2]] assert np.array_equal(coarse_lat_lon, lr_lat_lon) @@ -273,9 +257,9 @@ def test_spatial_dual_batch_handler( plt.close() -def test_validation_batching( - log=False, full_shape=(20, 20), sample_shape=(10, 10, 4) -): +def test_validation_batching(log=False, + full_shape=(20, 20), + sample_shape=(10, 10, 4)): """Test batching of validation data for dual batch handler""" if log: init_logger('sup3r', log_level='DEBUG') @@ -283,86 +267,64 @@ def test_validation_batching( s_enhance = 2 t_enhance = 2 - hr_handler = DataHandlerH5( - FP_WTK, - FEATURES, - target=TARGET_COORD, - shape=full_shape, - sample_shape=sample_shape, - temporal_slice=slice(None, None, 10), - worker_kwargs=dict(max_workers=1), - ) - lr_handler = DataHandlerNC( - FP_ERA, - FEATURES, - sample_shape=( - sample_shape[0] // s_enhance, - sample_shape[1] // s_enhance, - sample_shape[2] // t_enhance, - ), - temporal_slice=slice(None, None, t_enhance * 10), - worker_kwargs=dict(max_workers=1), - ) - - dual_handler = DualDataHandler( - hr_handler, - lr_handler, - s_enhance=s_enhance, - t_enhance=t_enhance, - val_split=0.1, - ) - - batch_handler = DualBatchHandler( - [dual_handler], - batch_size=2, - s_enhance=s_enhance, - t_enhance=t_enhance, - n_batches=10, - ) + hr_handler = DataHandlerH5(FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1)) + lr_handler = 
DataHandlerNC(FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // s_enhance, + sample_shape[1] // s_enhance, + sample_shape[2] // t_enhance), + temporal_slice=slice(None, None, + t_enhance * 10), + worker_kwargs=dict(max_workers=1)) + + dual_handler = DualDataHandler(hr_handler, + lr_handler, + s_enhance=s_enhance, + t_enhance=t_enhance, + val_split=0.1) + + batch_handler = DualBatchHandler([dual_handler], + batch_size=2, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=10) for batch in batch_handler.val_data: assert batch.high_res.dtype == np.dtype(np.float32) assert batch.low_res.dtype == np.dtype(np.float32) assert batch.low_res.shape[0] == batch.high_res.shape[0] - assert batch.low_res.shape == ( - batch.low_res.shape[0], - sample_shape[0] // s_enhance, - sample_shape[1] // s_enhance, - sample_shape[2] // t_enhance, - len(FEATURES), - ) - assert batch.high_res.shape == ( - batch.high_res.shape[0], - sample_shape[0], - sample_shape[1], - sample_shape[2], - len(FEATURES), - ) + assert batch.low_res.shape == (batch.low_res.shape[0], + sample_shape[0] // s_enhance, + sample_shape[1] // s_enhance, + sample_shape[2] // t_enhance, + len(FEATURES)) + assert batch.high_res.shape == (batch.high_res.shape[0], + sample_shape[0], sample_shape[1], + sample_shape[2], len(FEATURES)) for j, index in enumerate( - batch_handler.val_data.current_batch_indices - ): + batch_handler.val_data.current_batch_indices): hr_index = index['hr_index'] lr_index = index['lr_index'] - assert np.array_equal( - batch.high_res[j], - dual_handler.hr_val_data[hr_index], - ) - assert np.array_equal( - batch.low_res[j], - dual_handler.lr_val_data[lr_index], - ) + assert np.array_equal(batch.high_res[j], + dual_handler.hr_val_data[hr_index]) + assert np.array_equal(batch.low_res[j], + dual_handler.lr_val_data[lr_index]) coarse_lat_lon = spatial_coarsening( - dual_handler.hr_lat_lon[hr_index[:2]], obs_axis=False - ) + dual_handler.hr_lat_lon[hr_index[:2]], obs_axis=False) lr_lat_lon = dual_handler.lr_lat_lon[lr_index[:2]] assert np.array_equal(coarse_lat_lon, lr_lat_lon) - coarse_ti = dual_handler.hr_val_time_index[hr_index[2]][ - ::t_enhance - ] + coarse_ti = dual_handler.hr_val_time_index[ + hr_index[2]][::t_enhance] lr_ti = dual_handler.lr_val_time_index[lr_index[2]] assert np.array_equal(coarse_ti.values, lr_ti.values) From 5868508d33fdd51f2f43bdf705806c13402b6a82 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 31 Aug 2023 09:56:27 -0600 Subject: [PATCH 3/9] numpy version issue with regridder indexing --- requirements.txt | 1 - sup3r/utilities/regridder.py | 6 ++---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2da8f8c9c..06de3fb18 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,3 @@ netCDF4==1.5.8 dask sphinx pandas -numpy==1.22 diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index 003cf51d4..7254e56c8 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -691,8 +691,7 @@ def run(self, node_index): else: self._run_parallel(source_files=self.source_files, node_index=node_index, - max_workers=self.regrid_workers, - ) + max_workers=self.regrid_workers) def _run_serial(self, source_files, node_index): """Regrid data and write to output file, in serial. @@ -714,8 +713,7 @@ def _run_serial(self, source_files, node_index): msg = ('Coordinate chunks regridded: {} out of {}. 
' 'Current memory usage is {:.3f} GB out of {:.3f} ' 'GB total.'.format(i + 1, len(self.node_chunks[node_index]), - mem.used / 1e9, mem.total / 1e9, - )) + mem.used / 1e9, mem.total / 1e9)) logger.info(msg) def _run_parallel(self, source_files, node_index, max_workers=None): From ee1638a2bf076af752cbe49a5eb81fa8a7af1a24 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 1 Sep 2023 15:30:43 -0600 Subject: [PATCH 4/9] bc for dual data handler --- sup3r/bias/bias_transforms.py | 101 ++++++++---------- sup3r/preprocessing/data_handling/base.py | 8 +- .../data_handling/dual_data_handling.py | 28 ++++- 3 files changed, 79 insertions(+), 58 deletions(-) diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index 6fe7165fd..71d74f568 100644 --- a/sup3r/bias/bias_transforms.py +++ b/sup3r/bias/bias_transforms.py @@ -47,21 +47,18 @@ def get_spatial_bc_factors(lat_lon, feature_name, bias_fp, threshold=0.1): slice_x = slice(idx[0], idx[0] + lat_lon.shape[1]) if diff.min() > threshold: - msg = ( - 'The DataHandler top left coordinate of {} ' - 'appears to be {} away from the nearest ' - 'bias correction coordinate of {} from {}. ' - 'Cannot apply bias correction.'.format( - lat_lon, - diff.min(), - lat_lon_bc[idy, idx], - os.path.basename(bias_fp), - ) - ) + msg = ('The DataHandler top left coordinate of {} ' + 'appears to be {} away from the nearest ' + 'bias correction coordinate of {} from {}. ' + 'Cannot apply bias correction.'.format( + lat_lon, diff.min(), lat_lon_bc[idy, idx], + os.path.basename(bias_fp), + )) logger.error(msg) raise RuntimeError(msg) - assert dset_scalar in res.dsets and dset_adder in res.dsets + msg = (f'Either {dset_scalar} or {dset_adder} not found in {bias_fp}.') + assert dset_scalar in res.dsets and dset_adder in res.dsets, msg scalar = res[dset_scalar, slice_y, slice_x] adder = res[dset_adder, slice_y, slice_x] return scalar, adder @@ -94,15 +91,14 @@ def global_linear_bc(input, scalar, adder, out_range=None): return out -def local_linear_bc( - input, - lat_lon, - feature_name, - bias_fp, - lr_padded_slice, - out_range=None, - smoothing=0, -): +def local_linear_bc(input, + lat_lon, + feature_name, + bias_fp, + lr_padded_slice, + out_range=None, + smoothing=0, + ): """Bias correct data using a simple annual (or multi-year) *scalar +adder method on a site-by-site basis. 
@@ -156,10 +152,8 @@ def local_linear_bc( adder = adder[spatial_slice] if np.isnan(scalar).any() or np.isnan(adder).any(): - msg = ( - 'Bias correction scalar/adder values had NaNs for ' - f'"{feature_name}" from: {bias_fp}' - ) + msg = ('Bias correction scalar/adder values had NaNs for ' + f'"{feature_name}" from: {bias_fp}') logger.warning(msg) warn(msg) @@ -171,12 +165,12 @@ def local_linear_bc( if smoothing > 0: for idt in range(scalar.shape[-1]): - scalar[..., idt] = gaussian_filter( - scalar[..., idt], smoothing, mode='nearest' - ) - adder[..., idt] = gaussian_filter( - adder[..., idt], smoothing, mode='nearest' - ) + scalar[..., idt] = gaussian_filter(scalar[..., idt], + smoothing, + mode='nearest') + adder[..., idt] = gaussian_filter(adder[..., idt], + smoothing, + mode='nearest') out = input * scalar + adder if out_range is not None: @@ -186,17 +180,16 @@ def local_linear_bc( return out -def monthly_local_linear_bc( - input, - lat_lon, - feature_name, - bias_fp, - lr_padded_slice, - time_index, - temporal_avg=True, - out_range=None, - smoothing=0, -): +def monthly_local_linear_bc(input, + lat_lon, + feature_name, + bias_fp, + lr_padded_slice, + time_index, + temporal_avg=True, + out_range=None, + smoothing=0, + ): """Bias correct data using a simple monthly *scalar +adder method on a site-by-site basis. @@ -269,29 +262,25 @@ def monthly_local_linear_bc( scalar = np.repeat(scalar, input.shape[-1], axis=-1) adder = np.repeat(adder, input.shape[-1], axis=-1) if len(time_index.month.unique()) > 2: - msg = ( - 'Bias correction method "monthly_local_linear_bc" was used ' - 'with temporal averaging over a time index with >2 months.' - ) + msg = ('Bias correction method "monthly_local_linear_bc" was used ' + 'with temporal averaging over a time index with >2 months.') warn(msg) logger.warning(msg) if np.isnan(scalar).any() or np.isnan(adder).any(): - msg = ( - 'Bias correction scalar/adder values had NaNs for ' - f'"{feature_name}" from: {bias_fp}' - ) + msg = ('Bias correction scalar/adder values had NaNs for ' + f'"{feature_name}" from: {bias_fp}') logger.warning(msg) warn(msg) if smoothing > 0: for idt in range(scalar.shape[-1]): - scalar[..., idt] = gaussian_filter( - scalar[..., idt], smoothing, mode='nearest' - ) - adder[..., idt] = gaussian_filter( - adder[..., idt], smoothing, mode='nearest' - ) + scalar[..., idt] = gaussian_filter(scalar[..., idt], + smoothing, + mode='nearest') + adder[..., idt] = gaussian_filter(adder[..., idt], + smoothing, + mode='nearest') out = input * scalar + adder if out_range is not None: diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index d2251f3c8..b935bc790 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd +from rex import Resource from rex.utilities import log_mem from rex.utilities.fun_utils import get_fun_call_str from scipy.spatial import KDTree @@ -1151,7 +1152,12 @@ def lin_bc(self, bc_files, threshold=0.1): completed = [] for idf, feature in enumerate(self.features): for fp in bc_files: - if feature not in completed: + dset_scalar = f'{feature}_scalar' + dset_adder = f'{feature}_adder' + with Resource(fp) as res: + check = (dset_scalar in res.dsets + and dset_adder in res.dsets) + if feature not in completed and check: scalar, adder = get_spatial_bc_factors( lat_lon=self.lat_lon, feature_name=feature, diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py 
b/sup3r/preprocessing/data_handling/dual_data_handling.py index 522edd70e..967068c8b 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -29,6 +29,8 @@ def __init__(self, shuffle_time=False, s_enhance=15, t_enhance=1, + bc_files=None, + bc_threshold=0.1, val_split=0.0): """Initialize data handler using hr and lr data handlers for h5 data and nc data @@ -54,6 +56,19 @@ def __init__(self, Spatial enhancement factor t_enhance : int Temporal enhancement factor + bc_files : list | tuple | str | None + One or more filepaths to .h5 files output by + MonthlyLinearCorrection or LinearCorrection. Used to bias correct + low resolution data prior to regrdding. These should contain + datasets named "{feature}_scalar" and "{feature}_adder" where + {feature} is one of the features contained by this DataHandler and + the data is a 3D array of shape (lat, lon, time) where time is + length 1 for annual correction or 12 for monthly correction. Bias + correction is only run if bc_files is not None. + bc_threshold : float + Nearest neighbor euclidean distance threshold. If the DataHandler + coordinates are more than this value away from the bias correction + lat/lon, an error is raised. val_split : float Percentage of data to reserve for validation. """ @@ -78,6 +93,8 @@ def __init__(self, self.hr_time_index = None self.lr_val_time_index = None self.hr_val_time_index = None + self.bc_files = bc_files + self.bc_threshold = bc_threshold if self.try_load and self.load_cached: self.load_cached_data() @@ -281,14 +298,23 @@ def features(self): @property def data(self): """Get low res data. Same as self.lr_data but used to match property - used by batch handler""" + used by batch handler for computing means and stdevs""" return self.lr_data + @data.setter + def data(self, data): + """Set low res data. 
Same as lr_data.setter but used to match property + used by batch handler for computing means and stdevs""" + self.lr_data = data + @property def lr_input_data(self): """Get low res data used as input to regridding routine""" if self.lr_dh.data is None: self.lr_dh.load_cached_data() + if self.bc_files is not None: + logger.info('Running bias correction on low resolution data.') + self.lr_dh.lin_bc(self.bc_files, self.bc_threshold) return self.lr_dh.data[..., :self.lr_required_shape[2], :] @property From 16a818a7b28d74c723da05b21f3cb6ff8d7a2516 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 2 Sep 2023 11:31:27 -0600 Subject: [PATCH 5/9] missed high res normalization --- sup3r/models/base.py | 214 +++++++++++------- .../data_handling/dual_data_handling.py | 41 ++-- .../data_handling/test_dual_data_handling.py | 60 +++++ 3 files changed, 212 insertions(+), 103 deletions(-) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 7f7220cc7..9e97b18ec 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -21,11 +21,20 @@ class Sup3rGan(AbstractInterface, AbstractSingleModel): """Basic sup3r GAN model.""" - def __init__(self, gen_layers, disc_layers, loss='MeanSquaredError', - optimizer=None, learning_rate=1e-4, - optimizer_disc=None, learning_rate_disc=None, - history=None, meta=None, means=None, stdevs=None, - default_device=None, name=None): + def __init__(self, + gen_layers, + disc_layers, + loss='MeanSquaredError', + optimizer=None, + learning_rate=1e-4, + optimizer_disc=None, + learning_rate_disc=None, + history=None, + meta=None, + means=None, + stdevs=None, + default_device=None, + name=None): """ Parameters ---------- @@ -107,10 +116,10 @@ def __init__(self, gen_layers, disc_layers, loss='MeanSquaredError', self._gen = self.load_network(gen_layers, 'generator') self._disc = self.load_network(disc_layers, 'discriminator') - self._means = (means if means is None - else np.array(means).astype(np.float32)) - self._stdevs = (stdevs if stdevs is None - else np.array(stdevs).astype(np.float32)) + self._means = (means if means is None else np.array(means).astype( + np.float32)) + self._stdevs = (stdevs if stdevs is None else np.array(stdevs).astype( + np.float32)) def save(self, out_dir): """Save the GAN with its sub-networks to a directory. @@ -158,10 +167,10 @@ def load(cls, model_dir, verbose=True): Returns a pretrained gan model that was previously saved to out_dir """ if verbose: - logger.info('Loading GAN from disk in directory: {}' - .format(model_dir)) - msg = ('Active python environment versions: \n{}' - .format(pprint.pformat(VERSION_RECORD, indent=4))) + logger.info( + 'Loading GAN from disk in directory: {}'.format(model_dir)) + msg = ('Active python environment versions: \n{}'.format( + pprint.pformat(VERSION_RECORD, indent=4))) logger.info(msg) fp_gen = os.path.join(model_dir, 'model_gen.pkl') @@ -170,7 +179,10 @@ def load(cls, model_dir, verbose=True): return cls(fp_gen, fp_disc, **params) - def generate(self, low_res, norm_in=True, un_norm_out=True, + def generate(self, + low_res, + norm_in=True, + un_norm_out=True, exogenous_data=None): """Use the generator model to generate high res data from low res input. This is the public generate function. 
@@ -202,8 +214,8 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ exo_check = (exogenous_data is None or not self._needs_lr_exo(low_res)) - low_res = (low_res if exo_check - else np.concatenate((low_res, exogenous_data), axis=-1)) + low_res = (low_res if exo_check else np.concatenate( + (low_res, exogenous_data), axis=-1)) if norm_in and self._means is not None: low_res = self.norm_input(low_res) @@ -213,8 +225,8 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, try: hi_res = layer(hi_res) except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, hi_res.shape)) + msg = ('Could not run layer #{} "{}" on tensor of shape {}'. + format(i + 1, layer, hi_res.shape)) logger.error(msg) raise RuntimeError(msg) from e @@ -245,8 +257,8 @@ def _tf_generate(self, low_res): try: hi_res = layer(hi_res) except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, hi_res.shape)) + msg = ('Could not run layer #{} "{}" on tensor of shape {}'. + format(i + 1, layer, hi_res.shape)) logger.error(msg) raise RuntimeError(msg) from e @@ -305,8 +317,8 @@ def discriminate(self, hi_res, norm_in=False): try: out = layer(out) except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, out.shape)) + msg = ('Could not run layer #{} "{}" on tensor of shape {}'. + format(i + 1, layer, out.shape)) logger.error(msg) raise RuntimeError(msg) from e @@ -337,8 +349,8 @@ def _tf_discriminate(self, hi_res): try: out = layer(out) except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, out.shape)) + msg = ('Could not run layer #{} "{}" on tensor of shape {}'. 
+ format(i + 1, layer, out.shape)) logger.error(msg) raise RuntimeError(msg) from e @@ -397,23 +409,24 @@ def model_params(self): ------- dict """ - means = (self._means if self._means is None - else [float(m) for m in self._means]) - stdevs = (self._stdevs if self._stdevs is None - else [float(s) for s in self._stdevs]) + means = (self._means + if self._means is None else [float(m) for m in self._means]) + stdevs = (self._stdevs if self._stdevs is None else + [float(s) for s in self._stdevs]) config_optm_g = self.get_optimizer_config(self.optimizer) config_optm_d = self.get_optimizer_config(self.optimizer_disc) - model_params = {'name': self.name, - 'loss': self.loss_name, - 'version_record': self.version_record, - 'optimizer': config_optm_g, - 'optimizer_disc': config_optm_d, - 'means': means, - 'stdevs': stdevs, - 'meta': self.meta, - } + model_params = { + 'name': self.name, + 'loss': self.loss_name, + 'version_record': self.version_record, + 'optimizer': config_optm_g, + 'optimizer_disc': config_optm_d, + 'means': means, + 'stdevs': stdevs, + 'meta': self.meta, + } return model_params @@ -454,7 +467,8 @@ def init_weights(self, lr_shape, hr_shape, device=None): _ = self._tf_discriminate(hi_res) @staticmethod - def get_weight_update_fraction(history, comparison_key, + def get_weight_update_fraction(history, + comparison_key, update_bounds=(0.5, 0.95), update_frac=0.0): """Get the factor by which to multiply previous adversarial loss @@ -568,8 +582,9 @@ def calc_loss_disc(disc_out_true, disc_out_gen): # note that these have flipped labels from the generator # loss because of the opposite optimization goal logits = tf.concat([disc_out_true, disc_out_gen], axis=0) - labels = tf.concat([tf.ones_like(disc_out_true), - tf.zeros_like(disc_out_gen)], axis=0) + labels = tf.concat( + [tf.ones_like(disc_out_true), + tf.zeros_like(disc_out_gen)], axis=0) loss_disc = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels) @@ -577,8 +592,12 @@ def calc_loss_disc(disc_out_true, disc_out_gen): return loss_disc - def calc_loss(self, hi_res_true, hi_res_gen, weight_gen_advers=0.001, - train_gen=True, train_disc=False): + def calc_loss(self, + hi_res_true, + hi_res_gen, + weight_gen_advers=0.001, + train_gen=True, + train_disc=False): """Calculate the GAN loss function using generated and true high resolution data. @@ -610,8 +629,8 @@ def calc_loss(self, hi_res_true, hi_res_gen, weight_gen_advers=0.001, msg = ('The tensor shapes of the synthetic output {} and ' 'true high res {} did not have matching shape! ' 'Check the spatiotemporal enhancement multipliers in your ' - 'your model config and data handlers.' 
- .format(hi_res_gen.shape, hi_res_true.shape)) + 'your model config and data handlers.'.format( + hi_res_gen.shape, hi_res_true.shape)) logger.error(msg) raise RuntimeError(msg) @@ -630,11 +649,12 @@ def calc_loss(self, hi_res_true, hi_res_gen, weight_gen_advers=0.001, elif train_disc: loss = loss_disc - loss_details = {'loss_gen': loss_gen, - 'loss_gen_content': loss_gen_content, - 'loss_gen_advers': loss_gen_advers, - 'loss_disc': loss_disc, - } + loss_details = { + 'loss_gen': loss_gen, + 'loss_gen_content': loss_gen_content, + 'loss_gen_advers': loss_gen_advers, + 'loss_disc': loss_disc, + } return loss, loss_details @@ -661,9 +681,11 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): for val_batch in batch_handler.val_data: output_gen = self._tf_generate(val_batch.low_res) _, v_loss_details = self.calc_loss( - val_batch.high_res, output_gen, + val_batch.high_res, + output_gen, weight_gen_advers=weight_gen_advers, - train_gen=False, train_disc=False) + train_gen=False, + train_disc=False) loss_details = self.update_loss_details(loss_details, v_loss_details, @@ -672,8 +694,13 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): return loss_details - def train_epoch(self, batch_handler, weight_gen_advers, train_gen, - train_disc, disc_loss_bounds, multi_gpu=False): + def train_epoch(self, + batch_handler, + weight_gen_advers, + train_gen, + train_disc, + disc_loss_bounds, + multi_gpu=False): """Train the GAN for one epoch. Parameters @@ -727,19 +754,25 @@ def train_epoch(self, batch_handler, weight_gen_advers, train_gen, if only_gen or (train_gen and not gen_too_good): trained_gen = True b_loss_details = self.run_gradient_descent( - batch.low_res, batch.high_res, self.generator_weights, + batch.low_res, + batch.high_res, + self.generator_weights, weight_gen_advers=weight_gen_advers, optimizer=self.optimizer, - train_gen=True, train_disc=False, + train_gen=True, + train_disc=False, multi_gpu=multi_gpu) if only_disc or (train_disc and not disc_too_good): trained_disc = True b_loss_details = self.run_gradient_descent( - batch.low_res, batch.high_res, self.discriminator_weights, + batch.low_res, + batch.high_res, + self.discriminator_weights, weight_gen_advers=weight_gen_advers, optimizer=self.optimizer_disc, - train_gen=False, train_disc=True, + train_gen=False, + train_disc=True, multi_gpu=multi_gpu) b_loss_details['gen_trained_frac'] = float(trained_gen) @@ -751,24 +784,23 @@ def train_epoch(self, batch_handler, weight_gen_advers, train_gen, logger.debug('Batch {} out of {} has epoch-average ' '(gen / disc) loss of: ({:.2e} / {:.2e}). ' - 'Trained (gen / disc): ({} / {})' - .format(ib, len(batch_handler), - loss_details['train_loss_gen'], - loss_details['train_loss_disc'], - trained_gen, trained_disc)) + 'Trained (gen / disc): ({} / {})'.format( + ib, len(batch_handler), + loss_details['train_loss_gen'], + loss_details['train_loss_disc'], trained_gen, + trained_disc)) if all([not trained_gen, not trained_disc]): msg = ('For some reason none of the GAN networks trained ' - 'during batch {} out of {}!' - .format(ib, len(batch_handler))) + 'during batch {} out of {}!'.format( + ib, len(batch_handler))) logger.warning(msg) warn(msg) return loss_details def update_adversarial_weights(self, history, adaptive_update_fraction, - adaptive_update_bounds, - weight_gen_advers, + adaptive_update_bounds, weight_gen_advers, train_disc): """Update spatial / temporal adversarial loss weights based on training fraction history. 
@@ -805,7 +837,8 @@ def update_adversarial_weights(self, history, adaptive_update_fraction, update_frac = 1 if train_disc: update_frac = self.get_weight_update_fraction( - history, 'train_disc_trained_frac', + history, + 'train_disc_trained_frac', update_frac=adaptive_update_fraction, update_bounds=adaptive_update_bounds) weight_gen_advers *= update_frac @@ -816,7 +849,9 @@ def update_adversarial_weights(self, history, adaptive_update_fraction, return weight_gen_advers - def train(self, batch_handler, n_epoch, + def train(self, + batch_handler, + n_epoch, weight_gen_advers=0.001, train_gen=True, train_disc=True, @@ -900,20 +935,21 @@ def train(self, batch_handler, n_epoch, epochs = list(range(n_epoch)) if self._history is None: - self._history = pd.DataFrame( - columns=['elapsed_time']) + self._history = pd.DataFrame(columns=['elapsed_time']) self._history.index.name = 'epoch' else: epochs += self._history.index.values[-1] + 1 t0 = time.time() logger.info('Training model with adversarial weight: {} ' - 'for {} epochs starting at epoch {}' - .format(weight_gen_advers, n_epoch, epochs[0])) + 'for {} epochs starting at epoch {}'.format( + weight_gen_advers, n_epoch, epochs[0])) for epoch in epochs: - loss_details = self.train_epoch(batch_handler, weight_gen_advers, - train_gen, train_disc, + loss_details = self.train_epoch(batch_handler, + weight_gen_advers, + train_gen, + train_disc, disc_loss_bounds, multi_gpu=multi_gpu) @@ -925,8 +961,8 @@ def train(self, batch_handler, n_epoch, loss_details["train_loss_gen"], loss_details["train_loss_disc"]) - if all(loss in loss_details for loss - in ('val_loss_gen', 'val_loss_disc')): + if all(loss in loss_details + for loss in ('val_loss_gen', 'val_loss_disc')): msg += 'gen/disc val loss: {:.2e}/{:.2e} '.format( loss_details["val_loss_gen"], loss_details["val_loss_disc"]) @@ -937,20 +973,28 @@ def train(self, batch_handler, n_epoch, lr_d = self.get_optimizer_config( self.optimizer_disc)['learning_rate'] - extras = {'weight_gen_advers': weight_gen_advers, - 'disc_loss_bound_0': disc_loss_bounds[0], - 'disc_loss_bound_1': disc_loss_bounds[1], - 'learning_rate_gen': lr_g, - 'learning_rate_disc': lr_d} + extras = { + 'weight_gen_advers': weight_gen_advers, + 'disc_loss_bound_0': disc_loss_bounds[0], + 'disc_loss_bound_1': disc_loss_bounds[1], + 'learning_rate_gen': lr_g, + 'learning_rate_disc': lr_d + } weight_gen_advers = self.update_adversarial_weights( loss_details, adaptive_update_fraction, adaptive_update_bounds, weight_gen_advers, train_disc) - stop = self.finish_epoch(epoch, epochs, t0, loss_details, - checkpoint_int, out_dir, - early_stop_on, early_stop_threshold, - early_stop_n_epoch, extras=extras) + stop = self.finish_epoch(epoch, + epochs, + t0, + loss_details, + checkpoint_int, + out_dir, + early_stop_on, + early_stop_threshold, + early_stop_n_epoch, + extras=extras) if stop: break diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index 967068c8b..5f5786d50 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -85,6 +85,7 @@ def __init__(self, self.shuffle_time = shuffle_time self._lr_lat_lon = None self._hr_lat_lon = None + self._lr_input_data = None self.lr_data = None self.hr_data = None self.lr_val_data = None @@ -163,7 +164,7 @@ def _val_split_check(self): warn(msg) def normalize(self, means, stdevs): - """Normalize low_res data + """Normalize low_res and high_res data Parameters 
---------- @@ -174,11 +175,18 @@ def normalize(self, means, stdevs): dimensions (features) array of means for all features with same ordering as data features """ + logger.info('Normalizing low resolution data.') self._normalize(data=self.lr_data, val_data=self.lr_val_data, means=means, stds=stdevs, max_workers=self.lr_dh.norm_workers) + logger.info('Normalizing high resolution data.') + self._normalize(data=self.hr_data, + val_data=self.hr_val_data, + means=means, + stds=stdevs, + max_workers=self.hr_dh.norm_workers) @property def output_features(self): @@ -301,26 +309,18 @@ def data(self): used by batch handler for computing means and stdevs""" return self.lr_data - @data.setter - def data(self, data): - """Set low res data. Same as lr_data.setter but used to match property - used by batch handler for computing means and stdevs""" - self.lr_data = data - @property def lr_input_data(self): """Get low res data used as input to regridding routine""" - if self.lr_dh.data is None: - self.lr_dh.load_cached_data() - if self.bc_files is not None: - logger.info('Running bias correction on low resolution data.') - self.lr_dh.lin_bc(self.bc_files, self.bc_threshold) - return self.lr_dh.data[..., :self.lr_required_shape[2], :] - - @property - def shape(self): - """Get low_res shape""" - return self.lr_data.shape + if self._lr_input_data is None: + if self.lr_dh.data is None: + self.lr_dh.load_cached_data() + if self.bc_files is not None: + logger.info('Running bias correction on low resolution data.') + self.lr_dh.lin_bc(self.bc_files, self.bc_threshold) + self._lr_input_data = self.lr_dh.data[ + ..., :self.lr_required_shape[2], :] + return self._lr_input_data @property def lr_required_shape(self): @@ -329,6 +329,11 @@ def lr_required_shape(self): self.hr_dh.requested_shape[1] // self.s_enhance, self.hr_dh.requested_shape[2] // self.t_enhance) + @property + def shape(self): + """Get low_res shape""" + return (*self.lr_required_shape, len(self.features)) + @property def hr_required_shape(self): """Return required shape for high_res data""" diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index adbbe0d93..a72a12c37 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -328,3 +328,63 @@ def test_validation_batching(log=False, hr_index[2]][::t_enhance] lr_ti = dual_handler.lr_val_time_index[lr_index[2]] assert np.array_equal(coarse_ti.values, lr_ti.values) + + +def test_normalization(log=False, + full_shape=(20, 20), + sample_shape=(10, 10, 4)): + """Test correct normalization""" + if log: + init_logger('sup3r', log_level='DEBUG') + + s_enhance = 2 + t_enhance = 2 + + hr_handler = DataHandlerH5(FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1)) + lr_handler = DataHandlerNC(FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // s_enhance, + sample_shape[1] // s_enhance, + sample_shape[2] // t_enhance), + temporal_slice=slice(None, None, + t_enhance * 10), + worker_kwargs=dict(max_workers=1)) + + dual_handler = DualDataHandler(hr_handler, + lr_handler, + s_enhance=s_enhance, + t_enhance=t_enhance, + val_split=0.1) + + means = [ + np.nanmean(dual_handler.lr_data[..., i]) + for i in range(dual_handler.lr_data.shape[-1]) + ] + stdevs = [ + np.nanstd(dual_handler.lr_data[..., i] - means[i]) + for i in range(dual_handler.lr_data.shape[-1]) + ] + + batch_handler = 
DualBatchHandler([dual_handler], + batch_size=2, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=10) + assert np.allclose(batch_handler.means, means) + assert np.allclose(batch_handler.stds, stdevs) + stacked_data = np.concatenate( + [d.data for d in batch_handler.data_handlers], axis=2) + + for i in range(len(FEATURES)): + std = np.std(stacked_data[..., i]) + if std == 0: + std = 1 + mean = np.mean(stacked_data[..., i]) + assert np.allclose(std, 1, atol=1e-3) + assert np.allclose(mean, 0, atol=1e-3) From 01c133ee86a28740ade2acc2b401a310392cfc47 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 5 Sep 2023 10:22:39 -0600 Subject: [PATCH 6/9] get_lat_lon_df method for help with obs comparisons --- sup3r/preprocessing/data_handling/base.py | 60 ++++ .../data_handling/nc_data_handling.py | 282 ++++++------------ 2 files changed, 153 insertions(+), 189 deletions(-) diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index b935bc790..3c646eaa6 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -669,6 +669,66 @@ def preflight(self): f'load_workers={self.load_workers}, ' f'ti_workers={self.ti_workers}') + @staticmethod + def get_closest_lat_lon(lat_lon, target): + """Get closest indices to target lat lon + + Parameters + ---------- + lat_lon : ndarray + Array of lat/lon + (spatial_1, spatial_2, 2) + Last dimension in order of (lat, lon) + target : tuple + (lat, lon) for target coordinate + + Returns + ------- + row : int + row index for closest lat/lon to target lat/lon + col : int + col index for closest lat/lon to target lat/lon + """ + # shape of ll2 is (n, 2) where axis=1 is (lat, lon) + ll2 = np.vstack((lat_lon[..., 0].flatten(), lat_lon[..., + 1].flatten())).T + tree = KDTree(ll2) + _, i = tree.query(np.array(target)) + row, col = np.where((lat_lon[..., 0] == ll2[i, 0]) + & (lat_lon[..., 1] == ll2[i, 1])) + row = row[0] + col = col[0] + return row, col + + def get_lat_lon_df(self, target, features=None): + """Get timeseries for given target + + Parameters + ---------- + target : tuple + (lat, lon) for target coordinate + features : list | None + Optional list of features to include in returned data. If None then + all available features are returned. 
+ + Returns + ------- + df : pd.DataFrame + Pandas dataframe with columns for each feature and timeindex for + the given target + """ + row, col = self.get_closest_lat_lon(self.lat_lon, target) + df = pd.DataFrame() + df['time'] = self.time_index + if self.data is None: + self.load_cached_data() + data = self.data[row, col] + features = features if features is not None else self.features + for f in features: + i = self.features.index(f) + df[f] = data[:, i] + return df + @classmethod def get_lat_lon(cls, file_paths, raster_index, invert_lat=False): """Get lat/lon grid for requested target and shape diff --git a/sup3r/preprocessing/data_handling/nc_data_handling.py b/sup3r/preprocessing/data_handling/nc_data_handling.py index d655f4c5f..e83795986 100644 --- a/sup3r/preprocessing/data_handling/nc_data_handling.py +++ b/sup3r/preprocessing/data_handling/nc_data_handling.py @@ -18,17 +18,15 @@ from scipy.stats import mode from sup3r.preprocessing.data_handling.base import DataHandler, DataHandlerDC -from sup3r.preprocessing.feature_handling import (BVFreqMon, BVFreqSquaredNC, - ClearSkyRatioCC, Feature, - InverseMonNC, LatLonNC, - PotentialTempNC, PressureNC, - Rews, Shear, Tas, TasMax, - TasMin, TempNC, TempNCforCC, - UWind, VWind, - WinddirectionNC, WindspeedNC) +from sup3r.preprocessing.feature_handling import ( + BVFreqMon, BVFreqSquaredNC, ClearSkyRatioCC, Feature, InverseMonNC, + LatLonNC, PotentialTempNC, PressureNC, Rews, Shear, Tas, TasMax, TasMin, + TempNC, TempNCforCC, UWind, VWind, WinddirectionNC, WindspeedNC, +) from sup3r.utilities.interpolation import Interpolator from sup3r.utilities.utilities import (estimate_max_workers, get_time_dim_name, - np_to_pd_times) + np_to_pd_times, + ) np.random.seed(42) @@ -80,9 +78,8 @@ def extract_workers(self): proc_mem /= len(self.time_chunks) n_procs = len(self.time_chunks) * len(self.extract_features) n_procs = int(np.ceil(n_procs)) - extract_workers = estimate_max_workers( - self._extract_workers, proc_mem, n_procs - ) + extract_workers = estimate_max_workers(self._extract_workers, proc_mem, + n_procs) return extract_workers @classmethod @@ -146,10 +143,8 @@ def get_file_times(cls, file_paths, **kwargs): elif hasattr(handle, 'times'): time_index = np_to_pd_times(handle.times.values) else: - msg = ( - f'Could not get time_index for {file_paths}. ' - 'Assuming time independence.' - ) + msg = (f'Could not get time_index for {file_paths}. ' + 'Assuming time independence.') time_index = None logger.warning(msg) warnings.warn(msg) @@ -177,11 +172,8 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): time_index : pd.Datetimeindex List of times as a Datetimeindex """ - max_workers = ( - len(file_paths) - if max_workers is None - else np.min((max_workers, len(file_paths))) - ) + max_workers = (len(file_paths) if max_workers is None else np.min( + (max_workers, len(file_paths)))) if max_workers == 1: return cls.get_file_times(file_paths, **kwargs) ti = {} @@ -192,10 +184,8 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): future = exe.submit(cls.get_file_times, [f], **kwargs) futures[future] = {'idx': i, 'file': f} - logger.info( - f'Started building time index from {len(file_paths)} ' - f'files in {dt.now() - now}.' 
- ) + logger.info(f'Started building time index from {len(file_paths)} ' + f'files in {dt.now() - now}.') for i, future in enumerate(as_completed(futures)): try: @@ -203,10 +193,8 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): if val is not None: ti[futures[future]['idx']] = list(val) except Exception as e: - msg = ( - 'Error while getting time index from file ' - f'{futures[future]["file"]}.' - ) + msg = ('Error while getting time index from file ' + f'{futures[future]["file"]}.') logger.exception(msg) raise RuntimeError(msg) from e logger.debug(f'Stored {i+1} out of {len(futures)} file times') @@ -242,14 +230,13 @@ def feature_registry(cls): return registry @classmethod - def extract_feature( - cls, - file_paths, - raster_index, - feature, - time_slice=slice(None), - **kwargs, - ): + def extract_feature(cls, + file_paths, + raster_index, + feature, + time_slice=slice(None), + **kwargs, + ): """Extract single feature from data source. The requested feature can match exactly to one found in the source data or can have a matching prefix with a suffix specifying the height or pressure level @@ -278,10 +265,8 @@ def extract_feature( Data array for extracted feature (spatial_1, spatial_2, temporal) """ - logger.debug( - f'Extracting {feature} with time_slice={time_slice}, ' - f'raster_index={raster_index}, kwargs={kwargs}.' - ) + logger.debug(f'Extracting {feature} with time_slice={time_slice}, ' + f'raster_index={raster_index}, kwargs={kwargs}.') handle = cls.source_handler(file_paths, **kwargs) f_info = Feature(feature, handle) interp_height = f_info.height @@ -290,27 +275,20 @@ def extract_feature( if feature in handle or feature.lower() in handle: feat_key = feature if feature in handle else feature.lower() - fdata = cls.direct_extract( - handle, feat_key, raster_index, time_slice - ) + fdata = cls.direct_extract(handle, feat_key, raster_index, + time_slice) elif basename in handle or basename.lower() in handle: feat_key = basename if basename in handle else basename.lower() if interp_height is not None: fdata = Interpolator.interp_var_to_height( - handle, - feat_key, - raster_index, - np.float32(interp_height), + handle, feat_key, raster_index, np.float32(interp_height), time_slice, ) elif interp_pressure is not None: fdata = Interpolator.interp_var_to_pressure( - handle, - feat_key, - raster_index, - np.float32(interp_pressure), - time_slice, + handle, feat_key, raster_index, + np.float32(interp_pressure), time_slice, ) else: @@ -374,40 +352,6 @@ def get_full_domain(cls, file_paths): """ return cls.get_lat_lon(file_paths, [slice(None), slice(None)]) - @staticmethod - def get_closest_lat_lon(lat_lon, target): - """Get closest indices to target lat lon to use for lower left corner - of raster index - - Parameters - ---------- - lat_lon : ndarray - Array of lat/lon - (spatial_1, spatial_2, 2) - Last dimension in order of (lat, lon) - target : tuple - (lat, lon) for lower left corner - - Returns - ------- - row : int - row index for closest lat/lon to target lat/lon - col : int - col index for closest lat/lon to target lat/lon - """ - # shape of ll2 is (n, 2) where axis=1 is (lat, lon) - ll2 = np.vstack( - (lat_lon[..., 0].flatten(), lat_lon[..., 1].flatten()) - ).T - tree = KDTree(ll2) - _, i = tree.query(np.array(target)) - row, col = np.where( - (lat_lon[..., 0] == ll2[i, 0]) & (lat_lon[..., 1] == ll2[i, 1]) - ) - row = row[0] - col = col[0] - return row, col - @classmethod def compute_raster_index(cls, file_paths, target, grid_shape): """Get raster index for a given 
target and shape @@ -426,9 +370,9 @@ def compute_raster_index(cls, file_paths, target, grid_shape): list List of slices corresponding to extracted data region """ - lat_lon = cls.get_lat_lon( - file_paths[:1], [slice(None), slice(None)], invert_lat=False - ) + lat_lon = cls.get_lat_lon(file_paths[:1], + [slice(None), slice(None)], + invert_lat=False) cls._check_grid_extent(target, grid_shape, lat_lon) row, col = cls.get_closest_lat_lon(lat_lon, target) @@ -473,20 +417,14 @@ def _check_grid_extent(cls, target, grid_shape, lat_lon): min_lon = np.min(lat_lon[..., 1]) max_lat = np.max(lat_lon[..., 0]) max_lon = np.max(lat_lon[..., 1]) - logger.debug( - 'Calculating raster index from WRF file ' - f'for shape {grid_shape} and target {target}' - ) - logger.debug( - f'lat/lon (min, max): {min_lat}/{min_lon}, ' f'{max_lat}/{max_lon}' - ) - msg = ( - f'target {target} out of bounds with min lat/lon ' - f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}' - ) - assert ( - min_lat <= target[0] <= max_lat and min_lon <= target[1] <= max_lon - ), msg + logger.debug('Calculating raster index from WRF file ' + f'for shape {grid_shape} and target {target}') + logger.debug(f'lat/lon (min, max): {min_lat}/{min_lon}, ' + f'{max_lat}/{max_lon}') + msg = (f'target {target} out of bounds with min lat/lon ' + f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}') + assert (min_lat <= target[0] <= max_lat + and min_lon <= target[1] <= max_lon), msg @classmethod def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): @@ -506,20 +444,15 @@ def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): raster_index : list List of slices selecting region from entire available grid. """ - if ( - raster_index[0].stop > lat_lon.shape[0] - or raster_index[1].stop > lat_lon.shape[1] - or raster_index[0].start < 0 - or raster_index[1].start < 0 - ): - msg = ( - f'Invalid target {target}, shape {grid_shape}, and raster ' - f'{raster_index} for data domain of size ' - f'{lat_lon.shape[:-1]} with lower left corner ' - f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' - f' and upper right corner ({np.max(lat_lon[..., 0])}, ' - f'{np.max(lat_lon[..., 1])}).' 
- ) + if (raster_index[0].stop > lat_lon.shape[0] + or raster_index[1].stop > lat_lon.shape[1] + or raster_index[0].start < 0 or raster_index[1].start < 0): + msg = (f'Invalid target {target}, shape {grid_shape}, and raster ' + f'{raster_index} for data domain of size ' + f'{lat_lon.shape[:-1]} with lower left corner ' + f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' + f' and upper right corner ({np.max(lat_lon[..., 0])}, ' + f'{np.max(lat_lon[..., 1])}).') raise ValueError(msg) def get_raster_index(self): @@ -532,33 +465,23 @@ def get_raster_index(self): raster_index : np.ndarray 2D array of grid indices """ - self.raster_file = ( - self.raster_file - if self.raster_file is None - else self.raster_file.replace('.txt', '.npy') - ) + self.raster_file = (self.raster_file if self.raster_file is None else + self.raster_file.replace('.txt', '.npy')) if self.raster_file is not None and os.path.exists(self.raster_file): - logger.debug( - f'Loading raster index: {self.raster_file} ' - f'for {self.input_file_info}' - ) + logger.debug(f'Loading raster index: {self.raster_file} ' + f'for {self.input_file_info}') raster_index = np.load(self.raster_file, allow_pickle=True) raster_index = list(raster_index) else: check = self.grid_shape is not None and self.target is not None - msg = ( - 'Must provide raster file or shape + target to get ' - 'raster index' - ) + msg = ('Must provide raster file or shape + target to get ' + 'raster index') assert check, msg - raster_index = self.compute_raster_index( - self.file_paths, self.target, self.grid_shape - ) - logger.debug( - 'Found raster index with row, col slices: {}'.format( - raster_index - ) - ) + raster_index = self.compute_raster_index(self.file_paths, + self.target, + self.grid_shape) + logger.debug('Found raster index with row, col slices: {}'.format( + raster_index)) if self.raster_file is not None: basedir = os.path.dirname(self.raster_file) @@ -578,14 +501,13 @@ class DataHandlerNCforCC(DataHandlerNC): Chunk sizes that approximately match the data volume being extracted typically results in the most efficient IO.""" - def __init__( - self, - *args, - nsrdb_source_fp=None, - nsrdb_agg=1, - nsrdb_smoothing=0, - **kwargs, - ): + def __init__(self, + *args, + nsrdb_source_fp=None, + nsrdb_agg=1, + nsrdb_smoothing=0, + **kwargs, + ): """Initialize NETCDF data handler for climate change data. Parameters @@ -709,30 +631,22 @@ def get_clearsky_ghi(self): shape is (lat, lon, time) where time is daily average values. 
""" - msg = ( - 'Need nsrdb_source_fp input arg as a valid filepath to ' - 'retrieve clearsky_ghi (maybe for clearsky_ratio) but ' - 'received: {}'.format(self._nsrdb_source_fp) - ) + msg = ('Need nsrdb_source_fp input arg as a valid filepath to ' + 'retrieve clearsky_ghi (maybe for clearsky_ratio) but ' + 'received: {}'.format(self._nsrdb_source_fp)) assert self._nsrdb_source_fp is not None, msg assert os.path.exists(self._nsrdb_source_fp), msg - msg = ( - 'Can only handle source CC data in hourly frequency but ' - 'received daily frequency of {}hrs (should be 24) ' - 'with raw time index: {}'.format( - self.time_freq_hours, self.raw_time_index - ) - ) + msg = ('Can only handle source CC data in hourly frequency but ' + 'received daily frequency of {}hrs (should be 24) ' + 'with raw time index: {}'.format(self.time_freq_hours, + self.raw_time_index)) assert self.time_freq_hours == 24.0, msg - msg = ( - 'Can only handle source CC data with temporal_slice.step == 1 ' - 'but received: {}'.format(self.temporal_slice.step) - ) - assert (self.temporal_slice.step is None) | ( - self.temporal_slice.step == 1 - ), msg + msg = ('Can only handle source CC data with temporal_slice.step == 1 ' + 'but received: {}'.format(self.temporal_slice.step)) + assert (self.temporal_slice.step is None) | (self.temporal_slice.step + == 1), msg with Resource(self._nsrdb_source_fp) as res: ti_nsrdb = res.time_index @@ -758,15 +672,11 @@ def get_clearsky_ghi(self): if len(i.shape) == 1: i = np.expand_dims(i, axis=1) - logger.info( - 'Extracting clearsky_ghi data from "{}" with time slice ' - '{} and {} locations with agg factor {}.'.format( - os.path.basename(self._nsrdb_source_fp), - t_slice, - i.shape[0], - i.shape[1], - ) - ) + logger.info('Extracting clearsky_ghi data from "{}" with time slice ' + '{} and {} locations with agg factor {}.'.format( + os.path.basename(self._nsrdb_source_fp), t_slice, + i.shape[0], i.shape[1], + )) cs_shape = i.shape with Resource(self._nsrdb_source_fp) as res: @@ -775,9 +685,8 @@ def get_clearsky_ghi(self): cs_ghi = cs_ghi.reshape((len(cs_ghi), *cs_shape)) cs_ghi = cs_ghi.mean(axis=-1) - windows = np.array_split( - np.arange(len(cs_ghi)), len(cs_ghi) // (24 // time_freq) - ) + windows = np.array_split(np.arange(len(cs_ghi)), + len(cs_ghi) // (24 // time_freq)) cs_ghi = [cs_ghi[window].mean(axis=0) for window in windows] cs_ghi = np.vstack(cs_ghi) cs_ghi = cs_ghi.reshape((len(cs_ghi), *tuple(self.grid_shape))) @@ -786,15 +695,12 @@ def get_clearsky_ghi(self): if self.invert_lat: cs_ghi = cs_ghi[::-1] - logger.info( - 'Smoothing nsrdb clearsky ghi with a factor of {}'.format( - self._nsrdb_smoothing - ) - ) + logger.info('Smoothing nsrdb clearsky ghi with a factor of {}'.format( + self._nsrdb_smoothing)) for iday in range(cs_ghi.shape[-1]): - cs_ghi[..., iday] = gaussian_filter( - cs_ghi[..., iday], self._nsrdb_smoothing, mode='nearest' - ) + cs_ghi[..., iday] = gaussian_filter(cs_ghi[..., iday], + self._nsrdb_smoothing, + mode='nearest') if cs_ghi.shape[-1] < t_end_target: n = int(np.ceil(t_end_target / cs_ghi.shape[-1])) @@ -805,9 +711,7 @@ def get_clearsky_ghi(self): 'Reshaped clearsky_ghi data to final shape {} to ' 'correspond with CC daily average data over source ' 'temporal_slice {} with (lat, lon) grid shape of {}'.format( - cs_ghi.shape, self.temporal_slice, self.grid_shape - ) - ) + cs_ghi.shape, self.temporal_slice, self.grid_shape)) return cs_ghi From 0cf4b414119eb36a94f472477bb066bc1c408ef5 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 5 Sep 2023 18:30:31 -0600 
Subject: [PATCH 7/9] added weighted selection of data handler in batch handler --- sup3r/preprocessing/batch_handling.py | 621 +++++++----------- sup3r/preprocessing/data_handling/base.py | 11 + .../data_handling/dual_data_handling.py | 5 + sup3r/preprocessing/feature_handling.py | 133 ++-- sup3r/utilities/era_downloader.py | 101 +-- .../data_handling/test_dual_data_handling.py | 1 + 6 files changed, 372 insertions(+), 500 deletions(-) diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py index c89196a62..266a90c54 100644 --- a/sup3r/preprocessing/batch_handling.py +++ b/sup3r/preprocessing/batch_handling.py @@ -13,19 +13,11 @@ from scipy.ndimage.filters import gaussian_filter from sup3r.preprocessing.data_handling.h5_data_handling import ( - DataHandlerDCforH5, -) + DataHandlerDCforH5, ) from sup3r.utilities.utilities import ( - estimate_max_workers, - nn_fill_array, - nsrdb_reduce_daily_data, - smooth_data, - spatial_coarsening, - temporal_coarsening, - uniform_box_sampler, - uniform_time_sampler, - weighted_box_sampler, - weighted_time_sampler, + estimate_max_workers, nn_fill_array, nsrdb_reduce_daily_data, smooth_data, + spatial_coarsening, temporal_coarsening, uniform_box_sampler, + uniform_time_sampler, weighted_box_sampler, weighted_time_sampler, ) np.random.seed(42) @@ -94,18 +86,17 @@ def reduce_features(high_res, output_features_ind=None): # pylint: disable=W0613 @classmethod - def get_coarse_batch( - cls, - high_res, - s_enhance, - t_enhance=1, - temporal_coarsening_method='subsample', - output_features_ind=None, - output_features=None, - training_features=None, - smoothing=None, - smoothing_ignore=None, - ): + def get_coarse_batch(cls, + high_res, + s_enhance, + t_enhance=1, + temporal_coarsening_method='subsample', + output_features_ind=None, + output_features=None, + training_features=None, + smoothing=None, + smoothing_ignore=None, + ): """Coarsen high res data and return Batch with high res and low res data @@ -155,13 +146,11 @@ def get_coarse_batch( smoothing_ignore = [] if t_enhance != 1: - low_res = temporal_coarsening( - low_res, t_enhance, temporal_coarsening_method - ) + low_res = temporal_coarsening(low_res, t_enhance, + temporal_coarsening_method) - low_res = smooth_data( - low_res, training_features, smoothing_ignore, smoothing - ) + low_res = smooth_data(low_res, training_features, smoothing_ignore, + smoothing) high_res = cls.reduce_features(high_res, output_features_ind) batch = cls(low_res, high_res) @@ -174,18 +163,16 @@ class ValidationData: # Classes to use for handling an individual batch obj. 
BATCH_CLASS = Batch - def __init__( - self, - data_handlers, - batch_size=8, - s_enhance=3, - t_enhance=1, - temporal_coarsening_method='subsample', - output_features_ind=None, - output_features=None, - smoothing=None, - smoothing_ignore=None, - ): + def __init__(self, + data_handlers, + batch_size=8, + s_enhance=3, + t_enhance=1, + temporal_coarsening_method='subsample', + output_features_ind=None, + output_features=None, + smoothing=None, + smoothing_ignore=None): """ Parameters ---------- @@ -256,23 +243,32 @@ def _get_val_indices(self): if h.val_data is not None: for _ in range(h.val_data.shape[2]): spatial_slice = uniform_box_sampler( - h.val_data, self.sample_shape[:2] - ) + h.val_data, self.sample_shape[:2]) temporal_slice = uniform_time_sampler( - h.val_data, self.sample_shape[2] - ) - tuple_index = tuple( - [ - *spatial_slice, - temporal_slice, - np.arange(h.val_data.shape[-1]), - ] - ) - val_indices.append( - {'handler_index': i, 'tuple_index': tuple_index} - ) + h.val_data, self.sample_shape[2]) + tuple_index = tuple([ + *spatial_slice, temporal_slice, + np.arange(h.val_data.shape[-1]), + ]) + val_indices.append({ + 'handler_index': i, + 'tuple_index': tuple_index + }) return val_indices + @property + def handler_weights(self): + """Get weights used to sample from different data handlers based on + relative sizes""" + sizes = [dh.size for dh in self.data_handlers] + weights = sizes / np.sum(sizes) + return weights + + def get_handler_index(self): + """Get random handler index based on handler weights""" + indices = np.arange(0, len(self.data_handlers)) + return np.random.choice(indices, p=self.handler_weights) + def any(self): """Return True if any validation data exists""" return any(self.val_indices) @@ -291,12 +287,9 @@ def shape(self): time_steps = 0 for h in self.data_handlers: time_steps += h.val_data.shape[2] - return ( - self.data_handlers[0].val_data.shape[0], - self.data_handlers[0].val_data.shape[1], - time_steps, - self.data_handlers[0].val_data.shape[3], - ) + return (self.data_handlers[0].val_data.shape[0], + self.data_handlers[0].val_data.shape[1], time_steps, + self.data_handlers[0].val_data.shape[3]) def __iter__(self): self._i = 0 @@ -334,8 +327,7 @@ def batch_next(self, high_res): output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features, - ) + output_features=self.output_features) def __next__(self): """Get validation data batch @@ -354,20 +346,13 @@ def __next__(self): n_obs = self._remaining_observations high_res = np.zeros( - ( - n_obs, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.data_handlers[0].shape[-1], - ), - dtype=np.float32, - ) + (n_obs, self.sample_shape[0], self.sample_shape[1], + self.sample_shape[2], self.data_handlers[0].shape[-1]), + dtype=np.float32) for i in range(high_res.shape[0]): val_index = self.val_indices[self._i + i] - high_res[i, ...] = self.data_handlers[ - val_index['handler_index'] - ].val_data[val_index['tuple_index']] + high_res[i, ...] 
= self.data_handlers[val_index[ + 'handler_index']].val_data[val_index['tuple_index']] self._remaining_observations -= 1 self.current_batch_indices.append(val_index['handler_index']) @@ -388,24 +373,22 @@ class BatchHandler: BATCH_CLASS = Batch DATA_HANDLER_CLASS = None - def __init__( - self, - data_handlers, - batch_size=8, - s_enhance=3, - t_enhance=1, - means=None, - stds=None, - norm=True, - n_batches=10, - temporal_coarsening_method='subsample', - stdevs_file=None, - means_file=None, - overwrite_stats=False, - smoothing=None, - smoothing_ignore=None, - worker_kwargs=None, - ): + def __init__(self, + data_handlers, + batch_size=8, + s_enhance=3, + t_enhance=1, + means=None, + stds=None, + norm=True, + n_batches=10, + temporal_coarsening_method='subsample', + stdevs_file=None, + means_file=None, + overwrite_stats=False, + smoothing=None, + smoothing_ignore=None, + worker_kwargs=None): """ Parameters ---------- @@ -507,19 +490,17 @@ def __init__( f for f in self.training_features if f not in self.smoothing_ignore ] - logger.info( - f'Initializing BatchHandler with smoothing={smoothing}. ' - f'Using stats_workers={self.stats_workers}, ' - f'norm_workers={self.norm_workers}, ' - f'load_workers={self.load_workers}.' - ) + logger.info(f'Initializing BatchHandler with ' + f'{len(self.data_handlers)} data handlers with handler' + f'weights={self.handler_weights}, smoothing={smoothing}. ' + f'Using stats_workers={self.stats_workers}, ' + f'norm_workers={self.norm_workers}, ' + f'load_workers={self.load_workers}.') now = dt.now() self.parallel_load() - logger.debug( - f'Finished loading data of shape {self.shape} ' - f'for BatchHandler in {dt.now() - now}.' - ) + logger.debug(f'Finished loading data of shape {self.shape} ' + f'for BatchHandler in {dt.now() - now}.') log_mem(logger, log_level='INFO') if norm: @@ -542,6 +523,24 @@ def __init__( logger.info('Finished initializing BatchHandler.') log_mem(logger, log_level='INFO') + @property + def handler_weights(self): + """Get weights used to sample from different data handlers based on + relative sizes""" + sizes = [dh.size for dh in self.data_handlers] + weights = sizes / np.sum(sizes) + return weights + + def get_handler_index(self): + """Get random handler index based on handler weights""" + indices = np.arange(0, len(self.data_handlers)) + return np.random.choice(indices, p=self.handler_weights) + + def get_current_handler(self): + """Get random handler based on handler weights""" + self.current_handler_index = self.get_handler_index() + return self.data_handlers[self.current_handler_index] + @property def feature_mem(self): """Get memory used by each feature in data handlers""" @@ -551,18 +550,16 @@ def feature_mem(self): def stats_workers(self): """Get max workers for calculating stats based on memory usage""" proc_mem = self.feature_mem - stats_workers = estimate_max_workers( - self._stats_workers, proc_mem, len(self.data_handlers) - ) + stats_workers = estimate_max_workers(self._stats_workers, proc_mem, + len(self.data_handlers)) return stats_workers @property def load_workers(self): """Get max workers for loading data handler based on memory usage""" proc_mem = len(self.data_handlers[0].features) * self.feature_mem - max_workers = estimate_max_workers( - self._load_workers, proc_mem, len(self.data_handlers) - ) + max_workers = estimate_max_workers(self._load_workers, proc_mem, + len(self.data_handlers)) return max_workers @property @@ -570,9 +567,8 @@ def norm_workers(self): """Get max workers used for calculating and normalization 
across features""" proc_mem = 2 * self.feature_mem - norm_workers = estimate_max_workers( - self._norm_workers, proc_mem, len(self.training_features) - ) + norm_workers = estimate_max_workers(self._norm_workers, proc_mem, + len(self.training_features)) return norm_workers @property @@ -595,8 +591,7 @@ def output_features_ind(self): return None else: out = [ - i - for i, feature in enumerate(self.training_features) + i for i, feature in enumerate(self.training_features) if feature in self.output_features ] return out @@ -609,16 +604,13 @@ def shape(self): ------- shape : tuple (spatial_1, spatial_2, temporal, features) - With temporal extent equal to the sum across all data handlers time - dimension + With spatiotemporal extent equal to the sum across all data handler + dimensions """ time_steps = np.sum([h.shape[-2] for h in self.data_handlers]) - return ( - self.data_handlers[0].shape[0], - self.data_handlers[0].shape[1], - time_steps, - self.data_handlers[0].shape[-1], - ) + n_lons = self.data_handlers[0].shape[1] + n_lats = self.data_handlers[0].shape[0] + return (n_lats, n_lons, time_steps, self.data_handlers[0].shape[-1]) def parallel_normalization(self): """Normalize data in all data handlers in parallel.""" @@ -635,25 +627,19 @@ def parallel_normalization(self): future = exe.submit(d.normalize, self.means, self.stds) futures[future] = i - logger.info( - f'Started normalizing {len(self.data_handlers)} ' - f'data handlers in {dt.now() - now}.' - ) + logger.info(f'Started normalizing {len(self.data_handlers)} ' + f'data handlers in {dt.now() - now}.') for i, _ in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ( - 'Error normalizing data handler number ' - f'{futures[future]}' - ) + msg = ('Error normalizing data handler number ' + f'{futures[future]}') logger.exception(msg) raise RuntimeError(msg) from e - logger.debug( - f'{i+1} out of {len(futures)} data handlers' - ' normalized.' - ) + logger.debug(f'{i+1} out of {len(futures)} data handlers' + ' normalized.') def parallel_load(self): """Load data handler data in parallel""" @@ -672,31 +658,25 @@ def parallel_load(self): future = exe.submit(d.load_cached_data) futures[future] = i - logger.info( - f'Started loading all {len(self.data_handlers)} ' - f'data handlers in {dt.now() - now}.' - ) + logger.info(f'Started loading all {len(self.data_handlers)} ' + f'data handlers in {dt.now() - now}.') for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ( - 'Error loading data handler number ' - f'{futures[future]}' - ) + msg = ('Error loading data handler number ' + f'{futures[future]}') logger.exception(msg) raise RuntimeError(msg) from e - logger.debug( - f'{i+1} out of {len(futures)} handlers ' 'loaded.' - ) + logger.debug(f'{i+1} out of {len(futures)} handlers ' + 'loaded.') def parallel_stats(self): """Get standard deviations and means for training features in parallel.""" - logger.info( - f'Calculating stats for {len(self.training_features)} ' 'features.' - ) + logger.info(f'Calculating stats for {len(self.training_features)} ' + 'features.') max_workers = self.norm_workers if max_workers == 1: for f in self.training_features: @@ -709,27 +689,21 @@ def parallel_stats(self): future = exe.submit(self.get_stats_for_feature, f) futures[future] = i - logger.info( - 'Started calculating stats for ' - f'{len(self.training_features)} features in ' - f'{dt.now() - now}.' 
- ) + logger.info('Started calculating stats for ' + f'{len(self.training_features)} features in ' + f'{dt.now() - now}.') for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ( - 'Error calculating stats for ' - f'{self.training_features[futures[future]]}' - ) + msg = ('Error calculating stats for ' + f'{self.training_features[futures[future]]}') logger.exception(msg) raise RuntimeError(msg) from e - logger.debug( - f'{i+1} out of ' - f'{len(self.training_features)} stats ' - 'calculated.' - ) + logger.debug(f'{i+1} out of ' + f'{len(self.training_features)} stats ' + 'calculated.') def __len__(self): """Use user input of n_batches to specify length @@ -752,9 +726,8 @@ def check_cached_stats(self): stds : ndarray Array of stdevs for each feature """ - stdevs_check = ( - self.stdevs_file is not None and not self.overwrite_stats - ) + stdevs_check = (self.stdevs_file is not None + and not self.overwrite_stats) stdevs_check = stdevs_check and os.path.exists(self.stdevs_file) means_check = self.means_file is not None and not self.overwrite_stats means_check = means_check and os.path.exists(self.means_file) @@ -766,12 +739,10 @@ def check_cached_stats(self): with open(self.means_file, 'rb') as fh: self.means = pickle.load(fh) - msg = ( - 'The training features and cached statistics are ' - 'incompatible. Number of training features is ' - f'{len(self.training_features)} and number of stats is' - f' {len(self.stds)}' - ) + msg = ('The training features and cached statistics are ' + 'incompatible. Number of training features is ' + f'{len(self.training_features)} and number of stats is' + f' {len(self.stds)}') check = len(self.means) == len(self.training_features) check = check and (len(self.stds) == len(self.training_features)) assert check, msg @@ -822,9 +793,8 @@ def get_handler_mean(self, feature_idx, handler_idx): float Feature mean """ - return np.nanmean( - self.data_handlers[handler_idx].data[..., feature_idx] - ) + return np.nanmean(self.data_handlers[handler_idx].data[..., + feature_idx]) def get_handler_variance(self, feature_idx, handler_idx, mean): """Get feature variance for a given handler @@ -887,18 +857,14 @@ def get_means_for_feature(self, feature, max_workers=None): future = exe.submit(self.get_handler_mean, idx, didx) futures[future] = didx - logger.info( - 'Started calculating means for ' - f'{len(self.data_handlers)} data_handlers in ' - f'{dt.now() - now}.' - ) + logger.info('Started calculating means for ' + f'{len(self.data_handlers)} data_handlers in ' + f'{dt.now() - now}.') for i, future in enumerate(as_completed(futures)): self.means[idx] += future.result() - logger.debug( - f'{i+1} out of {len(self.data_handlers)} ' - 'means calculated.' 
- ) + logger.debug(f'{i+1} out of {len(self.data_handlers)} ' + 'means calculated.') self.means[idx] /= len(self.data_handlers) return self.means[idx] @@ -918,30 +884,24 @@ def get_stdevs_for_feature(self, feature, max_workers=None): if max_workers == 1: for didx, _ in enumerate(self.data_handlers): self.stds[idx] += self.get_handler_variance( - idx, didx, self.means[idx] - ) + idx, didx, self.means[idx]) else: with ThreadPoolExecutor(max_workers=max_workers) as exe: futures = {} now = dt.now() for didx, _ in enumerate(self.data_handlers): - future = exe.submit( - self.get_handler_variance, idx, didx, self.means[idx] - ) + future = exe.submit(self.get_handler_variance, idx, didx, + self.means[idx]) futures[future] = didx - logger.info( - 'Started calculating stdevs for ' - f'{len(self.data_handlers)} data_handlers in ' - f'{dt.now() - now}.' - ) + logger.info('Started calculating stdevs for ' + f'{len(self.data_handlers)} data_handlers in ' + f'{dt.now() - now}.') for i, future in enumerate(as_completed(futures)): self.stds[idx] += future.result() - logger.debug( - f'{i+1} out of {len(self.data_handlers)} ' - 'stdevs calculated.' - ) + logger.debug(f'{i+1} out of {len(self.data_handlers)} ' + 'stdevs calculated.') self.stds[idx] /= len(self.data_handlers) self.stds[idx] = np.sqrt(self.stds[idx]) return self.stds[idx] @@ -962,18 +922,15 @@ def normalize(self, means=None, stds=None): self.get_stats() elif means is not None and stds is not None: if not np.array_equal(means, self.means) or not np.array_equal( - stds, self.stds - ): + stds, self.stds): self.unnormalize() self.means = means self.stds = stds now = dt.now() logger.info('Normalizing data in each data handler.') self.parallel_normalization() - logger.info( - 'Finished normalizing data in all data handlers in ' - f'{dt.now() - now}.' - ) + logger.info('Finished normalizing data in all data handlers in ' + f'{dt.now() - now}.') def unnormalize(self): """Remove normalization from stored means and stds""" @@ -995,19 +952,11 @@ def __next__(self): """ self.current_batch_indices = [] if self._i < self.n_batches: - handler_index = np.random.randint(0, len(self.data_handlers)) - self.current_handler_index = handler_index - handler = self.data_handlers[handler_index] + handler = self.get_current_handler() high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.shape[-1], - ), - dtype=np.float32, - ) + (self.batch_size, self.sample_shape[0], self.sample_shape[1], + self.sample_shape[2], self.shape[-1]), + dtype=np.float32) for i in range(self.batch_size): high_res[i, ...] 
= handler.get_next() @@ -1022,8 +971,7 @@ def __next__(self): output_features=self.output_features, training_features=self.training_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore, - ) + smoothing_ignore=self.smoothing_ignore) self._i += 1 return batch @@ -1070,9 +1018,7 @@ def __next__(self): if self._i >= self.n_batches: raise StopIteration - handler_index = np.random.randint(0, len(self.data_handlers)) - self.current_handler_index = handler_index - handler = self.data_handlers[handler_index] + handler = self.get_current_handler() low_res = None high_res = None @@ -1082,8 +1028,7 @@ def __next__(self): self.current_batch_indices.append(handler.current_obs_index) obs_hourly = self.BATCH_CLASS.reduce_features( - obs_hourly, self.output_features_ind - ) + obs_hourly, self.output_features_ind) if low_res is None: lr_shape = (self.batch_size, *obs_daily_avg.shape) @@ -1097,25 +1042,22 @@ def __next__(self): high_res = self.reduce_high_res_sub_daily(high_res) low_res = spatial_coarsening(low_res, self.s_enhance) - if ( - self.output_features is not None - and 'clearsky_ratio' in self.output_features - ): + if (self.output_features is not None + and 'clearsky_ratio' in self.output_features): i_cs = self.output_features.index('clearsky_ratio') if np.isnan(high_res[..., i_cs]).any(): high_res[..., i_cs] = nn_fill_array(high_res[..., i_cs]) if self.smoothing is not None: feat_iter = [ - j - for j in range(low_res.shape[-1]) + j for j in range(low_res.shape[-1]) if self.training_features[j] not in self.smoothing_ignore ] for i in range(low_res.shape[0]): for j in feat_iter: - low_res[i, ..., j] = gaussian_filter( - low_res[i, ..., j], self.smoothing, mode='nearest' - ) + low_res[i, ..., j] = gaussian_filter(low_res[i, ..., j], + self.smoothing, + mode='nearest') batch = self.BATCH_CLASS(low_res, high_res) @@ -1182,9 +1124,7 @@ def __next__(self): if self._i >= self.n_batches: raise StopIteration - handler_index = np.random.randint(0, len(self.data_handlers)) - self.current_handler_index = handler_index - handler = self.data_handlers[handler_index] + handler = self.get_current_handler() high_res = None @@ -1196,12 +1136,9 @@ def __next__(self): hr_shape = (self.batch_size, *obs_daily_avg.shape) high_res = np.zeros(hr_shape, dtype=np.float32) - msg = ( - 'SpatialBatchHandlerCC can only use n_temporal==1 ' - 'but received HR shape {} with n_temporal={}.'.format( - hr_shape, hr_shape[3] - ) - ) + msg = ('SpatialBatchHandlerCC can only use n_temporal==1 ' + 'but received HR shape {} with n_temporal={}.'.format( + hr_shape, hr_shape[3])) assert hr_shape[3] == 1, msg high_res[i] = obs_daily_avg @@ -1210,29 +1147,25 @@ def __next__(self): low_res = low_res[:, :, :, 0, :] high_res = high_res[:, :, :, 0, :] - high_res = self.BATCH_CLASS.reduce_features( - high_res, self.output_features_ind - ) + high_res = self.BATCH_CLASS.reduce_features(high_res, + self.output_features_ind) - if ( - self.output_features is not None - and 'clearsky_ratio' in self.output_features - ): + if (self.output_features is not None + and 'clearsky_ratio' in self.output_features): i_cs = self.output_features.index('clearsky_ratio') if np.isnan(high_res[..., i_cs]).any(): high_res[..., i_cs] = nn_fill_array(high_res[..., i_cs]) if self.smoothing is not None: feat_iter = [ - j - for j in range(low_res.shape[-1]) + j for j in range(low_res.shape[-1]) if self.training_features[j] not in self.smoothing_ignore ] for i in range(low_res.shape[0]): for j in feat_iter: - low_res[i, ..., j] = gaussian_filter( - 
low_res[i, ..., j], self.smoothing, mode='nearest' - ) + low_res[i, ..., j] = gaussian_filter(low_res[i, ..., j], + self.smoothing, + mode='nearest') batch = self.BATCH_CLASS(low_res, high_res) @@ -1245,17 +1178,10 @@ class SpatialBatchHandler(BatchHandler): def __next__(self): if self._i < self.n_batches: - handler_index = np.random.randint(0, len(self.data_handlers)) - handler = self.data_handlers[handler_index] - high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.shape[-1], - ), - dtype=np.float32, - ) + handler = self.get_current_handler() + high_res = np.zeros((self.batch_size, self.sample_shape[0], + self.sample_shape[1], self.shape[-1]), + dtype=np.float32) for i in range(self.batch_size): high_res[i, ...] = handler.get_next()[..., 0, :] @@ -1265,8 +1191,7 @@ def __next__(self): output_features_ind=self.output_features_ind, training_features=self.training_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore, - ) + smoothing_ignore=self.smoothing_ignore) self._i += 1 return batch @@ -1295,69 +1220,58 @@ def _get_val_indices(self): val_indices = {} for t in range(self.N_TIME_BINS): val_indices[t] = [] - h_idx = np.random.choice(np.arange(len(self.data_handlers))) + h_idx = self.get_handler_index() h = self.data_handlers[h_idx] for _ in range(self.batch_size): - spatial_slice = uniform_box_sampler( - h.data, self.sample_shape[:2] - ) + spatial_slice = uniform_box_sampler(h.data, + self.sample_shape[:2]) weights = np.zeros(self.N_TIME_BINS) weights[t] = 1 - temporal_slice = weighted_time_sampler( - h.data, self.sample_shape[2], weights - ) - tuple_index = tuple( - [ - *spatial_slice, - temporal_slice, - np.arange(h.data.shape[-1]), - ] - ) - val_indices[t].append( - {'handler_index': h_idx, 'tuple_index': tuple_index} - ) + temporal_slice = weighted_time_sampler(h.data, + self.sample_shape[2], + weights) + tuple_index = tuple([ + *spatial_slice, temporal_slice, + np.arange(h.data.shape[-1]) + ]) + val_indices[t].append({ + 'handler_index': h_idx, + 'tuple_index': tuple_index + }) for s in range(self.N_SPACE_BINS): val_indices[s + self.N_TIME_BINS] = [] - h_idx = np.random.choice(np.arange(len(self.data_handlers))) + h_idx = self.get_handler_index() h = self.data_handlers[h_idx] for _ in range(self.batch_size): weights = np.zeros(self.N_SPACE_BINS) weights[s] = 1 - spatial_slice = weighted_box_sampler( - h.data, self.sample_shape[:2], weights - ) - temporal_slice = uniform_time_sampler( - h.data, self.sample_shape[2] - ) - tuple_index = tuple( - [ - *spatial_slice, - temporal_slice, - np.arange(h.data.shape[-1]), - ] - ) - val_indices[s + self.N_TIME_BINS].append( - {'handler_index': h_idx, 'tuple_index': tuple_index} - ) + spatial_slice = weighted_box_sampler(h.data, + self.sample_shape[:2], + weights) + temporal_slice = uniform_time_sampler(h.data, + self.sample_shape[2]) + tuple_index = tuple([ + *spatial_slice, temporal_slice, + np.arange(h.data.shape[-1]) + ]) + val_indices[s + self.N_TIME_BINS].append({ + 'handler_index': + h_idx, + 'tuple_index': + tuple_index + }) return val_indices def __next__(self): if self._i < len(self.val_indices.keys()): high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.data_handlers[0].shape[-1], - ), - dtype=np.float32, - ) + (self.batch_size, self.sample_shape[0], self.sample_shape[1], + self.sample_shape[2], self.data_handlers[0].shape[-1]), + dtype=np.float32) val_indices = self.val_indices[self._i] for i, 
idx in enumerate(val_indices): high_res[i, ...] = self.data_handlers[ - idx['handler_index'] - ].data[idx['tuple_index']] + idx['handler_index']].data[idx['tuple_index']] batch = self.BATCH_CLASS.get_coarse_batch( high_res, @@ -1367,8 +1281,7 @@ def __next__(self): output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features, - ) + output_features=self.output_features) self._i += 1 return batch else: @@ -1389,19 +1302,13 @@ class ValidationDataSpatialDC(ValidationDataDC): def __next__(self): if self._i < len(self.val_indices.keys()): high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.data_handlers[0].shape[-1], - ), - dtype=np.float32, - ) + (self.batch_size, self.sample_shape[0], self.sample_shape[1], + self.data_handlers[0].shape[-1]), + dtype=np.float32) val_indices = self.val_indices[self._i] for i, idx in enumerate(val_indices): high_res[i, ...] = self.data_handlers[ - idx['handler_index'] - ].data[idx['tuple_index']][..., 0, :] + idx['handler_index']].data[idx['tuple_index']][..., 0, :] batch = self.BATCH_CLASS.get_coarse_batch( high_res, @@ -1409,8 +1316,7 @@ def __next__(self): output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features, - ) + output_features=self.output_features) self._i += 1 return batch else: @@ -1440,15 +1346,12 @@ def __init__(self, *args, **kwargs): self.old_temporal_weights = [0] * self.val_data.N_TIME_BINS bin_range = self.data_handlers[0].data.shape[2] bin_range -= self.sample_shape[2] - 1 - self.temporal_bins = np.array_split( - np.arange(0, bin_range), self.val_data.N_TIME_BINS - ) + self.temporal_bins = np.array_split(np.arange(0, bin_range), + self.val_data.N_TIME_BINS) self.temporal_bins = [b[0] for b in self.temporal_bins] - logger.info( - 'Using temporal weights: ' - f'{[round(w, 3) for w in self.temporal_weights]}' - ) + logger.info('Using temporal weights: ' + f'{[round(w, 3) for w in self.temporal_weights]}') self.temporal_sample_record = [0] * self.val_data.N_TIME_BINS self.norm_temporal_record = [0] * self.val_data.N_TIME_BINS @@ -1467,24 +1370,15 @@ def __iter__(self): def __next__(self): self.current_batch_indices = [] if self._i < self.n_batches: - handler_index = np.random.randint(0, len(self.data_handlers)) - self.current_handler_index = handler_index - handler = self.data_handlers[handler_index] + handler = self.get_current_handler() high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.shape[-1], - ), - dtype=np.float32, - ) + (self.batch_size, self.sample_shape[0], self.sample_shape[1], + self.sample_shape[2], self.shape[-1]), + dtype=np.float32) for i in range(self.batch_size): high_res[i, ...] 
= handler.get_next( - temporal_weights=self.temporal_weights - ) + temporal_weights=self.temporal_weights) self.current_batch_indices.append(handler.current_obs_index) self.update_training_sample_record() @@ -1498,8 +1392,7 @@ def __next__(self): output_features=self.output_features, training_features=self.training_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore, - ) + smoothing_ignore=self.smoothing_ignore) self._i += 1 return batch @@ -1538,15 +1431,12 @@ def __init__(self, *args, **kwargs): self.max_cols = self.data_handlers[0].data.shape[1] + 1 self.max_cols -= self.sample_shape[1] bin_range = self.max_rows * self.max_cols - self.spatial_bins = np.array_split( - np.arange(0, bin_range), self.val_data.N_SPACE_BINS - ) + self.spatial_bins = np.array_split(np.arange(0, bin_range), + self.val_data.N_SPACE_BINS) self.spatial_bins = [b[0] for b in self.spatial_bins] - logger.info( - 'Using spatial weights: ' - f'{[round(w, 3) for w in self.spatial_weights]}' - ) + logger.info('Using spatial weights: ' + f'{[round(w, 3) for w in self.spatial_weights]}') self.spatial_sample_record = [0] * self.val_data.N_SPACE_BINS self.norm_spatial_record = [0] * self.val_data.N_SPACE_BINS @@ -1568,23 +1458,16 @@ def __iter__(self): def __next__(self): self.current_batch_indices = [] if self._i < self.n_batches: - handler_index = np.random.randint(0, len(self.data_handlers)) - self.current_handler_index = handler_index - handler = self.data_handlers[handler_index] - high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.shape[-1], - ), - dtype=np.float32, - ) + handler = self.get_current_handler() + high_res = np.zeros((self.batch_size, self.sample_shape[0], + self.sample_shape[1], self.shape[-1], + ), + dtype=np.float32, + ) for i in range(self.batch_size): high_res[i, ...] 
= handler.get_next( - spatial_weights=self.spatial_weights - )[..., 0, :] + spatial_weights=self.spatial_weights)[..., 0, :] self.current_batch_indices.append(handler.current_obs_index) self.update_training_sample_record() diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index 3c646eaa6..6aff69635 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -922,6 +922,17 @@ def shape(self): """ return self.data.shape + @property + def size(self): + """Size of data array + + Returns + ------- + size : int + Number of total elements contained in data array + """ + return np.product(self.requested_shape) + def cache_data(self, cache_file_paths): """Cache feature data to file and delete from memory diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index 5f5786d50..2e7a7cf3e 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -334,6 +334,11 @@ def shape(self): """Get low_res shape""" return (*self.lr_required_shape, len(self.features)) + @property + def size(self): + """Get low_res size""" + return np.product(self.shape) + @property def hr_required_shape(self): """Return required shape for high_res data""" diff --git a/sup3r/preprocessing/feature_handling.py b/sup3r/preprocessing/feature_handling.py index df784aa7a..bb2a27712 100644 --- a/sup3r/preprocessing/feature_handling.py +++ b/sup3r/preprocessing/feature_handling.py @@ -16,16 +16,10 @@ from rex import Resource from rex.utilities.execution import SpawnProcessPool -from sup3r.utilities.utilities import ( - bvf_squared, - get_raster_shape, - inverse_mo_length, - invert_pot_temp, - invert_uv, - rotor_equiv_ws, - transform_rotate_wind, - vorticity_calc, -) +from sup3r.utilities.utilities import (bvf_squared, get_raster_shape, + inverse_mo_length, invert_pot_temp, + invert_uv, rotor_equiv_ws, + transform_rotate_wind, vorticity_calc) np.random.seed(42) @@ -155,6 +149,7 @@ class CloudMaskH5(DerivedFeature): @classmethod def inputs(cls, feature): """Get list of raw features used in calculation of the cloud mask + Parameters ---------- feature : str @@ -243,8 +238,7 @@ class TempNC(DerivedFeature): def inputs(cls, feature): """Get list of inputs needed for compute method.""" height = Feature.get_height(feature) - features = [f'PotentialTemp_{height}m', - f'Pressure_{height}m'] + features = [f'PotentialTemp_{height}m', f'Pressure_{height}m'] return features @classmethod @@ -276,8 +270,7 @@ class PressureNC(DerivedFeature): def inputs(cls, feature): """Get list of inputs needed for compute method.""" height = Feature.get_height(feature) - features = [f'P_{height}m', - f'PB_{height}m'] + features = [f'P_{height}m', f'PB_{height}m'] return features @classmethod @@ -308,8 +301,7 @@ class BVFreqSquaredNC(DerivedFeature): def inputs(cls, feature): """Get list of inputs needed for compute method.""" height = Feature.get_height(feature) - features = [f'PT_{height}m', - f'PT_{int(height) - 100}m'] + features = [f'PT_{height}m', f'PT_{int(height) - 100}m'] return features @@ -449,10 +441,10 @@ def inputs(cls, feature): List of required features for computing BVF2 """ height = Feature.get_height(feature) - features = [f'temperature_{height}m', - f'temperature_{int(height) - 100}m', - f'pressure_{height}m', - f'pressure_{int(height) - 100}m'] + features = [ + f'temperature_{height}m', 
f'temperature_{int(height) - 100}m', + f'pressure_{height}m', f'pressure_{int(height) - 100}m' + ] return features @@ -473,12 +465,10 @@ def compute(cls, data, height): Derived feature array """ - return bvf_squared( - data[f'temperature_{height}m'], - data[f'temperature_{int(height) - 100}m'], - data[f'pressure_{height}m'], - data[f'pressure_{int(height) - 100}m'], - 100) + return bvf_squared(data[f'temperature_{height}m'], + data[f'temperature_{int(height) - 100}m'], + data[f'pressure_{height}m'], + data[f'pressure_{int(height) - 100}m'], 100) class WindspeedNC(DerivedFeature): @@ -753,8 +743,9 @@ def inputs(cls, feature): List of required features for computing U """ height = Feature.get_height(feature) - features = [f'windspeed_{height}m', f'winddirection_{height}m', - 'lat_lon'] + features = [ + f'windspeed_{height}m', f'winddirection_{height}m', 'lat_lon' + ] return features @classmethod @@ -819,8 +810,7 @@ def compute(cls, data, height): Derived feature array """ - vort = vorticity_calc(data[f'U_{height}m'], - data[f'V_{height}m']) + vort = vorticity_calc(data[f'U_{height}m'], data[f'V_{height}m']) return vort @@ -843,8 +833,9 @@ def inputs(cls, feature): List of required features for computing V """ height = Feature.get_height(feature) - features = [f'windspeed_{height}m', f'winddirection_{height}m', - 'lat_lon'] + features = [ + f'windspeed_{height}m', f'winddirection_{height}m', 'lat_lon' + ] return features @classmethod @@ -990,9 +981,9 @@ def compute(file_paths, raster_index): handle = xr.open_dataset(fp) valid_vars = set(handle.variables) lat_key = {'XLAT', 'lat', 'latitude'}.intersection(valid_vars) - lat_key = list(lat_key)[0] + lat_key = next(iter(lat_key)) lon_key = {'XLONG', 'lon', 'longitude'}.intersection(valid_vars) - lon_key = list(lon_key)[0] + lon_key = next(iter(lon_key)) if len(handle.variables[lat_key].dims) == 4: idx = (0, raster_index[0], raster_index[1], 0) @@ -1005,8 +996,8 @@ def compute(file_paths, raster_index): lons = handle.variables[lon_key].values lats = handle.variables[lat_key].values lons, lats = np.meshgrid(lons, lats) - lat_lon = np.dstack((lats[tuple(raster_index)], - lons[tuple(raster_index)])) + lat_lon = np.dstack( + (lats[tuple(raster_index)], lons[tuple(raster_index)])) else: lats = handle.variables[lat_key].values[idx] lons = handle.variables[lon_key].values[idx] @@ -1065,8 +1056,8 @@ def compute(file_paths, raster_index): """ with Resource(file_paths[0], hsds=False) as handle: lat_lon = handle.lat_lon[tuple([raster_index.flatten()])] - lat_lon = lat_lon.reshape((raster_index.shape[0], - raster_index.shape[1], 2)) + lat_lon = lat_lon.reshape( + (raster_index.shape[0], raster_index.shape[1], 2)) return lat_lon @@ -1188,8 +1179,9 @@ def valid_handle_features(cls, features, handle_features): if features is None: return False - return all(Feature.get_basename(f) in handle_features - or f in handle_features for f in features) + return all( + Feature.get_basename(f) in handle_features or f in handle_features + for f in features) @classmethod def valid_input_features(cls, features, handle_features): @@ -1211,9 +1203,10 @@ def valid_input_features(cls, features, handle_features): if features is None: return False - if all(Feature.get_basename(f) in handle_features - or f in handle_features - or cls.lookup(f, 'compute') is not None for f in features): + if all( + Feature.get_basename(f) in handle_features + or f in handle_features or cls.lookup(f, 'compute') is not None + for f in features): return True return False @@ -1235,8 +1228,7 @@ def 
pop_old_data(cls, data, chunk_number, all_features): """ if data: - old_keys = [f for f in data[chunk_number] - if f not in all_features] + old_keys = [f for f in data[chunk_number] if f not in all_features] for k in old_keys: data[chunk_number].pop(k) @@ -1282,8 +1274,13 @@ def serial_extract(cls, file_paths, raster_index, time_chunks, return data @classmethod - def parallel_extract(cls, file_paths, raster_index, time_chunks, - input_features, max_workers=None, **kwargs): + def parallel_extract(cls, + file_paths, + raster_index, + time_chunks, + input_features, + max_workers=None, + **kwargs): """Extract features using parallel subprocesses Parameters @@ -1382,7 +1379,8 @@ def recursive_compute(cls, data, feature, handle_features, file_paths, Array of computed feature data """ if feature not in data: - inputs = cls.lookup(feature, 'inputs', + inputs = cls.lookup(feature, + 'inputs', handle_features=handle_features) method = cls.lookup(feature, 'compute') height = Feature.get_height(feature) @@ -1393,10 +1391,8 @@ def recursive_compute(cls, data, feature, handle_features, file_paths, data[feature] = method(data, height) else: for r in inputs(feature): - data[r] = cls.recursive_compute(data, r, - handle_features, - file_paths, - raster_index) + data[r] = cls.recursive_compute( + data, r, handle_features, file_paths, raster_index) data[feature] = method(data, height) elif method is not None: data[feature] = method(file_paths, raster_index) @@ -1448,8 +1444,11 @@ def serial_compute(cls, data, file_paths, raster_index, time_chunks, for _, f in enumerate(derived_features): tmp = cls.get_input_arrays(data, t, f, handle_features) data[t][f] = cls.recursive_compute( - data=tmp, feature=f, handle_features=handle_features, - file_paths=file_paths, raster_index=raster_index) + data=tmp, + feature=f, + handle_features=handle_features, + file_paths=file_paths, + raster_index=raster_index) cls.pop_old_data(data, t, all_features) interval = int(np.ceil(len(time_chunks) / 10)) if t % interval == 0: @@ -1459,8 +1458,14 @@ def serial_compute(cls, data, file_paths, raster_index, time_chunks, return data @classmethod - def parallel_compute(cls, data, file_paths, raster_index, time_chunks, - derived_features, all_features, handle_features, + def parallel_compute(cls, + data, + file_paths, + raster_index, + time_chunks, + derived_features, + all_features, + handle_features, max_workers=None): """Compute features using parallel subprocesses @@ -1507,7 +1512,8 @@ def parallel_compute(cls, data, file_paths, raster_index, time_chunks, for t, _ in enumerate(time_chunks): for f in derived_features: tmp = cls.get_input_arrays(data, t, f, handle_features) - future = exe.submit(cls.recursive_compute, data=tmp, + future = exe.submit(cls.recursive_compute, + data=tmp, feature=f, handle_features=handle_features, file_paths=file_paths, @@ -1711,9 +1717,8 @@ def get_inputs_recursive(cls, feature, handle_features): lower_handle_features = [f.lower() for f in handle_features] check1 = feature not in raw_features - check2 = (cls.valid_handle_features([feature.lower()], - lower_handle_features) - or method is None) + check2 = (cls.valid_handle_features( + [feature.lower()], lower_handle_features) or method is None) if check1 and check2: raw_features.append(feature) @@ -1775,8 +1780,12 @@ def feature_registry(cls): @classmethod @abstractmethod - def extract_feature(cls, file_paths, raster_index, feature, - time_slice=slice(None), **kwargs): + def extract_feature(cls, + file_paths, + raster_index, + feature, + 
time_slice=slice(None), + **kwargs): """Extract single feature from data source Parameters diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index de0080ebf..c6e6dc2ca 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -8,11 +8,9 @@ import logging import os from calendar import monthrange -from concurrent.futures import ( - ProcessPoolExecutor, - ThreadPoolExecutor, - as_completed, -) +from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor, + as_completed, + ) from glob import glob from typing import ClassVar from warnings import warn @@ -46,66 +44,34 @@ class EraDownloader: assert os.path.exists(req_file), msg VALID_VARIABLES: ClassVar[list] = [ - 'u', - 'v', - 'pressure', - 'temperature', - 'relative_humidity', - 'specific_humidity', - 'total_precipitation', + 'u', 'v', 'pressure', 'temperature', 'relative_humidity', + 'specific_humidity', 'total_precipitation', ] - KEEP_VARIABLES: ClassVar[list] = [ - 'orog', - 'time', - 'latitude', - 'longitude', - ] + KEEP_VARIABLES: ClassVar[list] = ['orog'] KEEP_VARIABLES += [f'{v}_' for v in VALID_VARIABLES] DEFAULT_RENAMED_VARS: ClassVar[list] = [ - 'zg', - 'orog', - 'u', - 'v', - 'u_10m', - 'v_10m', - 'u_100m', - 'v_100m', - 'temperature', - 'pressure', + 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', + 'temperature', 'pressure', ] DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ - '10m_u_component_of_wind', - '10m_v_component_of_wind', - '100m_u_component_of_wind', - '100m_v_component_of_wind', - 'u_component_of_wind', - 'v_component_of_wind', - '2m_temperature', - 'temperature', - 'surface_pressure', - 'relative_humidity', + '10m_u_component_of_wind', '10m_v_component_of_wind', + '100m_u_component_of_wind', '100m_v_component_of_wind', + 'u_component_of_wind', 'v_component_of_wind', '2m_temperature', + 'temperature', 'surface_pressure', 'relative_humidity', 'total_precipitation', ] SFC_VARS: ClassVar[list] = [ - '10m_u_component_of_wind', - '10m_v_component_of_wind', - '100m_u_component_of_wind', - '100m_v_component_of_wind', - 'surface_pressure', - '2m_temperature', - 'geopotential', + '10m_u_component_of_wind', '10m_v_component_of_wind', + '100m_u_component_of_wind', '100m_v_component_of_wind', + 'surface_pressure', '2m_temperature', 'geopotential', 'total_precipitation', ] LEVEL_VARS: ClassVar[list] = [ - 'u_component_of_wind', - 'v_component_of_wind', - 'geopotential', - 'temperature', - 'relative_humidity', - 'specific_humidity', + 'u_component_of_wind', 'v_component_of_wind', 'geopotential', + 'temperature', 'relative_humidity', 'specific_humidity', ] NAME_MAP: ClassVar[dict] = { 'u10': 'u_10m', @@ -119,7 +85,7 @@ class EraDownloader: 'sp': 'pressure_0m', 'r': 'relative_humidity', 'q': 'specific_humidity', - 'tp': 'total_precip', + 'tp': 'total_precipitation', } def __init__(self, @@ -418,11 +384,10 @@ def map_vars(self, old_ds, ds): """ for old_name, new_name in self.NAME_MAP.items(): if old_name in old_ds.variables: - _ = ds.createVariable( - new_name, - np.float32, - dimensions=old_ds[old_name].dimensions, - ) + _ = ds.createVariable(new_name, + np.float32, + dimensions=old_ds[old_name].dimensions, + ) vals = old_ds.variables[old_name][:] if 'temperature' in new_name: vals -= 273.15 @@ -528,6 +493,7 @@ def good_file(self, file, required_shape): Whether or not data has required shape and variables. 
""" out = self.check_single_file(file, + var_list=self.variables, check_nans=False, check_heights=False, required_shape=required_shape) @@ -896,11 +862,9 @@ def _check_single_file(cls, Percent of data which consists of NaNs across all given variables. """ good_vars = all(var in res for var in var_list) - res_shape = ( - *res['level'].shape, - *res['latitude'].shape, - *res['longitude'].shape, - ) + res_shape = (*res['level'].shape, *res['latitude'].shape, + *res['longitude'].shape, + ) good_shape = ('NA' if required_shape is None else (res_shape == required_shape)) good_hgts = ('NA' if not check_heights else cls.check_heights( @@ -912,8 +876,8 @@ def _check_single_file(cls, res, var_list=var_list)) if not good_vars: - mask = np.array([var not in res for var in var_list]) - missing_vars = var_list[mask] + mask = [var not in res for var in var_list] + missing_vars = np.array(var_list)[mask] logger.error(f'Missing variables: {missing_vars}.') if good_shape != 'NA' and not good_shape: logger.error(f'Bad shape: {res_shape} != {required_shape}.') @@ -961,11 +925,10 @@ def check_heights(cls, res, max_interp_height=200, max_workers=10): futures = [] with ProcessPoolExecutor(max_workers=max_workers) as exe: for idt in range(heights.shape[0]): - future = exe.submit( - cls._check_heights_single_ts, - heights[idt], - max_interp_height=max_interp_height, - ) + future = exe.submit(cls._check_heights_single_ts, + heights[idt], + max_interp_height=max_interp_height, + ) futures.append(future) msg = (f'Submitted height check for {idt + 1} of ' f'{heights.shape[0]}') diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index a72a12c37..e74bb5a0b 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -169,6 +169,7 @@ def test_st_dual_batch_handler(log=False, s_enhance=s_enhance, t_enhance=t_enhance, n_batches=10) + assert np.allclose(batch_handler.handler_weights, 0.5) for batch in batch_handler: From 6e84e7a5e7a8329cf17bca4a512eaf99d847de27 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 12 Sep 2023 15:14:42 -0600 Subject: [PATCH 8/9] renamed get_current_handler -> get_rand_handler. changed get_closest_lat_lon to use min(dist) instead of kdtree. kept as static method since it is called by class methods. 
--- sup3r/preprocessing/batch_handling.py | 20 ++- sup3r/preprocessing/data_handling/base.py | 176 +++++++++++----------- 2 files changed, 100 insertions(+), 96 deletions(-) diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py index 266a90c54..9e62e0f8f 100644 --- a/sup3r/preprocessing/batch_handling.py +++ b/sup3r/preprocessing/batch_handling.py @@ -536,7 +536,7 @@ def get_handler_index(self): indices = np.arange(0, len(self.data_handlers)) return np.random.choice(indices, p=self.handler_weights) - def get_current_handler(self): + def get_rand_handler(self): """Get random handler based on handler weights""" self.current_handler_index = self.get_handler_index() return self.data_handlers[self.current_handler_index] @@ -952,7 +952,7 @@ def __next__(self): """ self.current_batch_indices = [] if self._i < self.n_batches: - handler = self.get_current_handler() + handler = self.get_rand_handler() high_res = np.zeros( (self.batch_size, self.sample_shape[0], self.sample_shape[1], self.sample_shape[2], self.shape[-1]), @@ -1018,7 +1018,7 @@ def __next__(self): if self._i >= self.n_batches: raise StopIteration - handler = self.get_current_handler() + handler = self.get_rand_handler() low_res = None high_res = None @@ -1124,7 +1124,7 @@ def __next__(self): if self._i >= self.n_batches: raise StopIteration - handler = self.get_current_handler() + handler = self.get_rand_handler() high_res = None @@ -1178,7 +1178,7 @@ class SpatialBatchHandler(BatchHandler): def __next__(self): if self._i < self.n_batches: - handler = self.get_current_handler() + handler = self.get_rand_handler() high_res = np.zeros((self.batch_size, self.sample_shape[0], self.sample_shape[1], self.shape[-1]), dtype=np.float32) @@ -1255,10 +1255,8 @@ def _get_val_indices(self): np.arange(h.data.shape[-1]) ]) val_indices[s + self.N_TIME_BINS].append({ - 'handler_index': - h_idx, - 'tuple_index': - tuple_index + 'handler_index': h_idx, + 'tuple_index': tuple_index }) return val_indices @@ -1370,7 +1368,7 @@ def __iter__(self): def __next__(self): self.current_batch_indices = [] if self._i < self.n_batches: - handler = self.get_current_handler() + handler = self.get_rand_handler() high_res = np.zeros( (self.batch_size, self.sample_shape[0], self.sample_shape[1], self.sample_shape[2], self.shape[-1]), @@ -1458,7 +1456,7 @@ def __iter__(self): def __next__(self): self.current_batch_indices = [] if self._i < self.n_batches: - handler = self.get_current_handler() + handler = self.get_rand_handler() high_res = np.zeros((self.batch_size, self.sample_shape[0], self.sample_shape[1], self.shape[-1], ), diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index 6aff69635..e63a4a13f 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -15,23 +15,36 @@ from rex import Resource from rex.utilities import log_mem from rex.utilities.fun_utils import get_fun_call_str -from scipy.spatial import KDTree from sup3r.bias.bias_transforms import get_spatial_bc_factors from sup3r.preprocessing.data_handling.mixin import (InputMixIn, TrainingPrepMixIn, ) -from sup3r.preprocessing.feature_handling import ( - BVFreqMon, BVFreqSquaredNC, Feature, FeatureHandler, InverseMonNC, - LatLonNC, PotentialTempNC, PressureNC, Rews, Shear, TempNC, UWind, VWind, - WinddirectionNC, WindspeedNC, -) +from sup3r.preprocessing.feature_handling import (BVFreqMon, + BVFreqSquaredNC, + Feature, + FeatureHandler, + InverseMonNC, + LatLonNC, + 
PotentialTempNC, + PressureNC, + Rews, + Shear, + TempNC, + UWind, + VWind, + WinddirectionNC, + WindspeedNC, + ) from sup3r.utilities import ModuleName from sup3r.utilities.cli import BaseCLI from sup3r.utilities.interpolation import Interpolator -from sup3r.utilities.utilities import (estimate_max_workers, get_chunk_slices, - get_raster_shape, np_to_pd_times, - spatial_coarsening, uniform_box_sampler, +from sup3r.utilities.utilities import (estimate_max_workers, + get_chunk_slices, + get_raster_shape, + np_to_pd_times, + spatial_coarsening, + uniform_box_sampler, uniform_time_sampler, weighted_box_sampler, weighted_time_sampler, @@ -54,7 +67,9 @@ class DataHandler(FeatureHandler, InputMixIn, TrainingPrepMixIn): # model but are not part of the synthetic output and are not sent to the # discriminator. These are case-insensitive and follow the Unix shell-style # wildcard format. - TRAIN_ONLY_FEATURES = ('BVF*', 'inversemoninobukhovlength_*', 'RMOL', + TRAIN_ONLY_FEATURES = ('BVF*', + 'inversemoninobukhovlength_*', + 'RMOL', 'topography', ) @@ -246,8 +261,11 @@ def __init__(self, self._load_workers = self.worker_kwargs.get('load_workers', None) self._compute_workers = self.worker_kwargs.get('compute_workers', None) self._worker_attrs = [ - '_ti_workers', '_norm_workers', '_compute_workers', - '_extract_workers', '_load_workers' + '_ti_workers', + '_norm_workers', + '_compute_workers', + '_extract_workers', + '_load_workers' ] self.preflight() @@ -298,7 +316,8 @@ def __init__(self, @property def try_load(self): """Check if we should try to load cache""" - return self._should_load_cache(self._cache_pattern, self.cache_files, + return self._should_load_cache(self._cache_pattern, + self.cache_files, self.overwrite_cache) def check_clear_data(self): @@ -339,8 +358,9 @@ def _val_split_check(self): msg = (f'Validation data has shape={self.val_data.shape} ' f'and sample_shape={self.sample_shape}. 
Use a smaller ' 'sample_shape and/or larger val_split.') - check = any(val_size < samp_size for val_size, samp_size in zip( - self.val_data.shape, self.sample_shape)) + check = any( + val_size < samp_size for val_size, + samp_size in zip(self.val_data.shape, self.sample_shape)) if check: logger.warning(msg) warnings.warn(msg) @@ -394,7 +414,8 @@ def extract_workers(self): proc_mem /= len(self.time_chunks) n_procs = len(self.time_chunks) * len(self.extract_features) n_procs = int(np.ceil(n_procs)) - extract_workers = estimate_max_workers(self._extract_workers, proc_mem, + extract_workers = estimate_max_workers(self._extract_workers, + proc_mem, n_procs) return extract_workers @@ -410,7 +431,8 @@ def compute_workers(self): proc_mem /= len(self.time_chunks) n_procs = len(self.time_chunks) * len(self.derive_features) n_procs = int(np.ceil(n_procs)) - compute_workers = estimate_max_workers(self._compute_workers, proc_mem, + compute_workers = estimate_max_workers(self._compute_workers, + proc_mem, n_procs) return compute_workers @@ -422,7 +444,8 @@ def load_workers(self): n_procs = 1 if self.cache_files is not None: n_procs = len(self.cache_files) - load_workers = estimate_max_workers(self._load_workers, proc_mem, + load_workers = estimate_max_workers(self._load_workers, + proc_mem, n_procs) return load_workers @@ -689,13 +712,9 @@ def get_closest_lat_lon(lat_lon, target): col : int col index for closest lat/lon to target lat/lon """ - # shape of ll2 is (n, 2) where axis=1 is (lat, lon) - ll2 = np.vstack((lat_lon[..., 0].flatten(), lat_lon[..., - 1].flatten())).T - tree = KDTree(ll2) - _, i = tree.query(np.array(target)) - row, col = np.where((lat_lon[..., 0] == ll2[i, 0]) - & (lat_lon[..., 1] == ll2[i, 1])) + dist = np.hypot(lat_lon[..., 0] - target[0], + lat_lon[..., 1] - target[1]) + row, col = np.where(dist == np.min(dist)) row = row[0] col = col[0] return row, col @@ -831,8 +850,11 @@ def get_cache_file_names(self, target = target if target is not None else self.target features = features if features is not None else self.features - return self._get_cache_file_names(cache_pattern, grid_shape, - time_index, target, features) + return self._get_cache_file_names(cache_pattern, + grid_shape, + time_index, + target, + features) def unnormalize(self, means, stds): """Remove normalization from stored means and stds""" @@ -941,7 +963,9 @@ def cache_data(self, cache_file_paths): cache_file_paths : str | None Path to file for saving feature data """ - self._cache_data(self.data, self.features, cache_file_paths, + self._cache_data(self.data, + self.features, + cache_file_paths, self.overwrite_cache) @property @@ -968,8 +992,10 @@ def load_cached_data(self, with_split=True): elif self.data is None: msg = ('Found {} cache files but need {} for features {}! 
' 'These are the cache files that were found: {}'.format( - len(self.cache_files), len(self.features), - self.features, self.cache_files)) + len(self.cache_files), + len(self.features), + self.features, + self.cache_files)) assert len(self.cache_files) == len(self.features), msg self.data = np.full(shape=self.requested_shape, @@ -1083,17 +1109,23 @@ def run_data_compute(self): logger.info(f'Starting computation of {self.derive_features}') if self.compute_workers == 1: - self._raw_data = self.serial_compute( - self._raw_data, self.file_paths, self.raster_index, - self.time_chunks, self.derive_features, - self.noncached_features, self.handle_features) + self._raw_data = self.serial_compute(self._raw_data, + self.file_paths, + self.raster_index, + self.time_chunks, + self.derive_features, + self.noncached_features, + self.handle_features) elif self.compute_workers != 1: - self._raw_data = self.parallel_compute( - self._raw_data, self.file_paths, self.raster_index, - self.time_chunks, self.derive_features, - self.noncached_features, self.handle_features, - self.compute_workers) + self._raw_data = self.parallel_compute(self._raw_data, + self.file_paths, + self.raster_index, + self.time_chunks, + self.derive_features, + self.noncached_features, + self.handle_features, + self.compute_workers) logger.info(f'Finished computing {self.derive_features} for ' f'{self.input_file_info}') @@ -1149,10 +1181,11 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None): max available workers will be used. If 1 cached data will be loaded in serial """ - self.data = np.zeros( - (self.grid_shape[0], self.grid_shape[1], self.n_tsteps, - len(self.features)), - dtype=np.float32) + self.data = np.zeros((self.grid_shape[0], + self.grid_shape[1], + self.n_tsteps, + len(self.features)), + dtype=np.float32) if max_workers == 1: self.serial_data_fill(shifted_time_chunks) @@ -1298,8 +1331,7 @@ def get_observation_index(self, self.sample_shape[2]) return tuple( - [*spatial_slice, temporal_slice, - np.arange(len(self.features))]) + [*spatial_slice, temporal_slice, np.arange(len(self.features))]) def get_next(self, temporal_weights=None, spatial_weights=None): """Get data for observation using weighted random observation index. @@ -1486,19 +1518,27 @@ def extract_feature(cls, if feature in handle or feature.lower() in handle: feat_key = feature if feature in handle else feature.lower() - fdata = cls.direct_extract(handle, feat_key, raster_index, + fdata = cls.direct_extract(handle, + feat_key, + raster_index, time_slice) elif basename in handle or basename.lower() in handle: feat_key = basename if basename in handle else basename.lower() if interp_height is not None: fdata = Interpolator.interp_var_to_height( - handle, feat_key, raster_index, np.float32(interp_height), + handle, + feat_key, + raster_index, + np.float32(interp_height), time_slice) elif interp_pressure is not None: fdata = Interpolator.interp_var_to_pressure( - handle, feat_key, raster_index, - np.float32(interp_pressure), time_slice) + handle, + feat_key, + raster_index, + np.float32(interp_pressure), + time_slice) else: msg = f'{feature} cannot be extracted from source data.' 
@@ -1561,38 +1601,6 @@ def get_full_domain(cls, file_paths):
         """
         return cls.get_lat_lon(file_paths, [slice(None), slice(None)])

-    @staticmethod
-    def get_closest_lat_lon(lat_lon, target):
-        """Get closest indices to target lat lon to use for lower left corner
-        of raster index
-
-        Parameters
-        ----------
-        lat_lon : ndarray
-            Array of lat/lon
-            (spatial_1, spatial_2, 2)
-            Last dimension in order of (lat, lon)
-        target : tuple
-            (lat, lon) for lower left corner
-
-        Returns
-        -------
-        row : int
-            row index for closest lat/lon to target lat/lon
-        col : int
-            col index for closest lat/lon to target lat/lon
-        """
-        # shape of ll2 is (n, 2) where axis=1 is (lat, lon)
-        ll2 = np.vstack((lat_lon[..., 0].flatten(), lat_lon[...,
-                                                            1].flatten())).T
-        tree = KDTree(ll2)
-        _, i = tree.query(np.array(target))
-        row, col = np.where((lat_lon[..., 0] == ll2[i, 0])
-                            & (lat_lon[..., 1] == ll2[i, 1]))
-        row = row[0]
-        col = col[0]
-        return row, col
-
     @classmethod
     def compute_raster_index(cls, file_paths, target, grid_shape):
         """Get raster index for a given target and shape

@@ -1611,8 +1619,7 @@ def compute_raster_index(cls, file_paths, target, grid_shape):
         list
             List of slices corresponding to extracted data region
         """
-        lat_lon = cls.get_lat_lon(file_paths[:1],
-                                  [slice(None), slice(None)],
+        lat_lon = cls.get_lat_lon(file_paths[:1], [slice(None), slice(None)],
                                   invert_lat=False)
         cls._check_grid_extent(target, grid_shape, lat_lon)

@@ -1632,8 +1639,7 @@ def compute_raster_index(cls, file_paths, target, grid_shape):
             row_end = row + grid_shape[0]
             row_start = row
         raster_index = [
-            slice(row_start, row_end),
-            slice(col, col + grid_shape[1])
+            slice(row_start, row_end), slice(col, col + grid_shape[1])
         ]
         cls._validate_raster_shape(target, grid_shape, lat_lon, raster_index)
         return raster_index

From 5ec22061755950f3c17981df85c5e9bfba173b9f Mon Sep 17 00:00:00 2001
From: bnb32
Date: Thu, 14 Sep 2023 09:52:12 -0600
Subject: [PATCH 9/9] bc removed from dual_data_handler in favor of performing
 it on lr_handler before sending through dual_data_handler.

---
 .../data_handling/dual_data_handling.py       | 23 +------------------
 1 file changed, 1 insertion(+), 22 deletions(-)

diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py
index 2e7a7cf3e..438cb8934 100644
--- a/sup3r/preprocessing/data_handling/dual_data_handling.py
+++ b/sup3r/preprocessing/data_handling/dual_data_handling.py
@@ -6,8 +6,7 @@
 import pandas as pd

 from sup3r.preprocessing.data_handling.mixin import (CacheHandlingMixIn,
-                                                     TrainingPrepMixIn,
-                                                     )
+                                                     TrainingPrepMixIn)
 from sup3r.utilities.regridder import Regridder
 from sup3r.utilities.utilities import spatial_coarsening

@@ -29,8 +28,6 @@ def __init__(self,
                  shuffle_time=False,
                  s_enhance=15,
                  t_enhance=1,
-                 bc_files=None,
-                 bc_threshold=0.1,
                  val_split=0.0):
         """Initialize data handler using hr and lr data handlers for h5 data
         and nc data

@@ -56,19 +53,6 @@ def __init__(self,
             Spatial enhancement factor
         t_enhance : int
             Temporal enhancement factor
-        bc_files : list | tuple | str | None
-            One or more filepaths to .h5 files output by
-            MonthlyLinearCorrection or LinearCorrection. Used to bias correct
-            low resolution data prior to regrdding. These should contain
-            datasets named "{feature}_scalar" and "{feature}_adder" where
-            {feature} is one of the features contained by this DataHandler and
-            the data is a 3D array of shape (lat, lon, time) where time is
-            length 1 for annual correction or 12 for monthly correction. Bias
-            correction is only run if bc_files is not None.
-        bc_threshold : float
-            Nearest neighbor euclidean distance threshold. If the DataHandler
-            coordinates are more than this value away from the bias correction
-            lat/lon, an error is raised.
         val_split : float
             Percentage of data to reserve for validation.
         """
@@ -94,8 +78,6 @@ def __init__(self,
         self.hr_time_index = None
         self.lr_val_time_index = None
         self.hr_val_time_index = None
-        self.bc_files = bc_files
-        self.bc_threshold = bc_threshold

         if self.try_load and self.load_cached:
             self.load_cached_data()
@@ -315,9 +297,6 @@ def lr_input_data(self):
         if self._lr_input_data is None:
             if self.lr_dh.data is None:
                 self.lr_dh.load_cached_data()
-            if self.bc_files is not None:
-                logger.info('Running bias correction on low resolution data.')
-                self.lr_dh.lin_bc(self.bc_files, self.bc_threshold)
             self._lr_input_data = self.lr_dh.data[
                 ..., :self.lr_required_shape[2], :]
         return self._lr_input_data
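With the `bc_files` and `bc_threshold` arguments removed, bias correction is meant to be applied to the low-resolution handler before it is passed to `DualDataHandler`, so regridding already sees corrected data. A rough usage sketch under that assumption follows; the file paths, feature names, target, shape, and bias-correction file are placeholders, the import paths are assumed from the sup3r package layout in this patch series, and the `lin_bc` call simply mirrors the removed `self.lr_dh.lin_bc(self.bc_files, self.bc_threshold)` line:

```python
from sup3r.preprocessing.data_handling import DataHandlerH5, DataHandlerNC
from sup3r.preprocessing.data_handling.dual_data_handling import (
    DualDataHandler)

# Hypothetical handlers -- paths, features, target, and shape are placeholders.
hr_handler = DataHandlerH5('hr_wtk_data.h5', ['U_100m', 'V_100m'],
                           target=(39.0, -105.0), shape=(20, 20),
                           val_split=0.0)
lr_handler = DataHandlerNC('lr_era5_data.nc', ['U_100m', 'V_100m'],
                           val_split=0.0)

# Bias-correct the low-res data up front, instead of passing bc_files and
# bc_threshold into DualDataHandler as the pre-patch API did.
lr_handler.lin_bc(['bc_factors.h5'], 0.1)

# The dual handler now regrids low-res data that has already been corrected.
dual_handler = DualDataHandler(hr_handler, lr_handler,
                               s_enhance=15, t_enhance=1, val_split=0.1)
```

Keeping the correction on the low-res handler also means the same corrected `DataHandler` can be cached and reused outside the dual-handler workflow.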