diff --git a/.gitignore b/.gitignore index 29385ddab..af6cd09a0 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,7 @@ nosetests.xml coverage.xml *.cover .hypothesis/ +*.png # Translations *.mo diff --git a/requirements.txt b/requirements.txt index 1242e486e..2da8f8c9c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,9 @@ matplotlib>=3.1 NREL-rex>=0.2.82 NREL-phygnn>=0.0.23 NREL-rev<0.8.0 +NREL-gaps>=0.4.0 NREL-farms>=1.0.4 +google-auth-oauthlib==0.5.3 pytest>=5.2 pillow tensorflow>2.4 @@ -11,3 +13,4 @@ netCDF4==1.5.8 dask sphinx pandas +numpy==1.22 diff --git a/sup3r/batch/batch.py b/sup3r/batch/batch.py index 843ee322b..f379d9be9 100644 --- a/sup3r/batch/batch.py +++ b/sup3r/batch/batch.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- """sup3r batch utilities based on reV's batch module""" +from gaps.legacy import BatchJob as GapsBatchJob + from sup3r.pipeline.pipeline import Sup3rPipeline from sup3r.pipeline.pipeline_cli import pipeline_monitor_background -from reV.batch.batch import BatchJob as RevBatchJob -class BatchJob(RevBatchJob): +class BatchJob(GapsBatchJob): """Framework for building a batched job suite.""" # Class attributes to set the software's pipeline class and run+monitor in diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 3dc4fd7a2..fa4c0567b 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -13,9 +13,9 @@ import pandas as pd import rex from rex.utilities.fun_utils import get_fun_call_str +from scipy import stats from scipy.ndimage.filters import gaussian_filter from scipy.spatial import KDTree -from scipy import stats import sup3r.preprocessing.data_handling from sup3r.utilities import VERSION_RECORD, ModuleName @@ -30,10 +30,17 @@ class DataRetrievalBase: baseline data """ - def __init__(self, base_fps, bias_fps, base_dset, bias_feature, - target=None, shape=None, base_handler='Resource', + def __init__(self, + base_fps, + bias_fps, + base_dset, + bias_feature, + target=None, + shape=None, + base_handler='Resource', bias_handler='DataHandlerNCforCC', - bias_handler_kwargs=None, decimals=None): + bias_handler_kwargs=None, + decimals=None): """ Parameters ---------- @@ -75,8 +82,8 @@ def __init__(self, base_fps, bias_fps, base_dset, bias_feature, """ logger.info('Initializing DataRetrievalBase for base dset "{}" ' - 'correcting biased dataset(s): {}' - .format(base_dset, bias_feature)) + 'correcting biased dataset(s): {}'.format( + base_dset, bias_feature)) self.base_fps = base_fps self.bias_fps = bias_fps self.base_dset = base_dset @@ -100,9 +107,10 @@ def __init__(self, base_fps, bias_fps, base_dset, bias_feature, self.base_tree = KDTree(self.base_meta[['latitude', 'longitude']]) self.bias_dh = self.bias_handler(self.bias_fps, [self.bias_feature], - target=self.target, shape=self.shape, - val_split=0.0, **bias_handler_kwargs) - + target=self.target, + shape=self.shape, + val_split=0.0, + **bias_handler_kwargs) lats = self.bias_dh.lat_lon[..., 0].flatten() lons = self.bias_dh.lat_lon[..., 1].flatten() self.bias_meta = pd.DataFrame({'latitude': lats, 'longitude': lons}) @@ -116,8 +124,8 @@ def __init__(self, base_fps, bias_fps, base_dset, bias_feature, @property def meta(self): - """Get a meta data dictionary on how these bias factors were calculated - """ + """Get a meta data dictionary on how these bias factors were + calculated""" meta = {'base_fps': self.base_fps, 'bias_fps': self.bias_fps, 'base_dset': self.base_dset, @@ -125,8 +133,7 @@ def meta(self): 'target': self.target, 'shape': self.shape, 'class': 
str(self.__class__), - 'version_record': VERSION_RECORD, - } + 'version_record': VERSION_RECORD} return meta @staticmethod @@ -167,7 +174,7 @@ def get_node_cmd(cls, config): initialize the class and call run() on a single node. """ import_str = 'import time;\n' - import_str += 'from reV.pipeline.status import Status;\n' + import_str += 'from sup3r.pipeline import Status;\n' import_str += 'from rex import init_logger;\n' import_str += f'from sup3r.bias.bias_calc import {cls.__name__};\n' @@ -185,7 +192,7 @@ def get_node_cmd(cls, config): log_file = config.get('log_file', None) log_level = config.get('log_level', 'INFO') - log_arg_str = (f'"sup3r", log_level="{log_level}"') + log_arg_str = f'"sup3r", log_level="{log_level}"' if log_file is not None: log_arg_str += f', log_file="{log_file}"' @@ -197,7 +204,7 @@ def get_node_cmd(cls, config): "t_elap = time.time() - t0;\n") cmd = BaseCLI.add_status_cmd(config, ModuleName.BIAS_CALC, cmd) - cmd += (";\'\n") + cmd += ";\'\n" return cmd.replace('\\', '/') @@ -278,7 +285,9 @@ def get_data_pair(self, coord, knn, daily_reduction='avg'): bias_gid, bias_dist = self.get_bias_gid(coord) base_dist, base_gid = self.get_base_gid(bias_gid, knn) bias_data = self.get_bias_data(bias_gid) - base_data = self.get_base_data(self.base_fps, self.base_dset, base_gid, + base_data = self.get_base_data(self.base_fps, + self.base_dset, + base_gid, self.base_handler, daily_reduction=daily_reduction, decimals=self.decimals) @@ -309,8 +318,8 @@ def get_bias_data(self, bias_gid): bias_data = bias_data[:, 0] else: msg = ('Found a weird number of feature channels for the bias ' - 'data retrieval: {}. Need just one channel' - .format(bias_data.shape)) + 'data retrieval: {}. Need just one channel'.format( + bias_data.shape)) logger.error(msg) raise RuntimeError(msg) @@ -320,8 +329,13 @@ def get_bias_data(self, bias_gid): return bias_data @classmethod - def get_base_data(cls, base_fps, base_dset, base_gid, base_handler, - daily_reduction='avg', decimals=None): + def get_base_data(cls, + base_fps, + base_dset, + base_gid, + base_handler, + daily_reduction='avg', + decimals=None): """Get data from the baseline data source, possibly for many high-res base gids corresponding to a single coarse low-res bias gid. 
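# A minimal standalone sketch (illustrative coordinates, not the
# DataRetrievalBase implementation) of the KDTree pairing used above: each
# coarse "bias" grid cell is matched to its knn nearest high-resolution "base"
# sites by lat/lon distance before the base and bias data are paired.
import numpy as np
from scipy.spatial import KDTree

base_latlon = np.array([[39.0, -105.0], [39.1, -105.0],
                        [39.0, -105.1], [40.0, -104.0]])   # base meta (4 sites)
bias_coord = (39.05, -105.05)                              # one coarse bias cell

tree = KDTree(base_latlon)
dist, base_gid = tree.query(bias_coord, k=3)   # knn=3 base gids per bias gid
print(base_gid, dist)
# run() later skips a bias site if np.mean(dist) exceeds the distance threshold.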
@@ -364,12 +378,16 @@ def get_base_data(cls, base_fps, base_dset, base_gid, base_handler, with base_handler(fp) as res: base_ti = res.time_index - base_data, base_cs_ghi = cls._read_base_data(res, base_dset, - base_gid) + base_data, base_cs_ghi = cls._read_base_data( + res, base_dset, base_gid) if daily_reduction is not None: - base_data = cls._reduce_base_data(base_ti, base_data, - base_cs_ghi, base_dset, - daily_reduction) + base_data = cls._reduce_base_data( + base_ti, + base_data, + base_cs_ghi, + base_dset, + daily_reduction, + ) base_ti = np.array(sorted(set(base_ti.date))) out.append(base_data) @@ -467,12 +485,14 @@ def _reduce_base_data(base_ti, base_data, base_cs_ghi, base_dset, if daily_reduction is None: return base_data - slices = [np.where(base_ti.date == date) - for date in sorted(set(base_ti.date))] + slices = [ + np.where(base_ti.date == date) + for date in sorted(set(base_ti.date)) + ] if base_dset == 'clearsky_ratio' and daily_reduction.lower() == 'avg': - base_data = np.array([base_data[s0].sum() / base_cs_ghi[s0].sum() - for s0 in slices]) + base_data = np.array( + [base_data[s0].sum() / base_cs_ghi[s0].sum() for s0 in slices]) elif daily_reduction.lower() == 'avg': base_data = np.array([base_data[s0].mean() for s0 in slices]) @@ -551,34 +571,47 @@ def get_linear_correction(bias_data, base_data, bias_feature, base_dset): scalar = np.nanstd(base_data) / bias_std adder = np.nanmean(base_data) - np.nanmean(bias_data) * scalar - out = {f'bias_{bias_feature}_mean': np.nanmean(bias_data), - f'bias_{bias_feature}_std': bias_std, - f'base_{base_dset}_mean': np.nanmean(base_data), - f'base_{base_dset}_std': np.nanstd(base_data), - f'{bias_feature}_scalar': scalar, - f'{bias_feature}_adder': adder, - } + out = { + f'bias_{bias_feature}_mean': np.nanmean(bias_data), + f'bias_{bias_feature}_std': bias_std, + f'base_{base_dset}_mean': np.nanmean(base_data), + f'base_{base_dset}_std': np.nanstd(base_data), + f'{bias_feature}_scalar': scalar, + f'{bias_feature}_adder': adder, + } return out # pylint: disable=W0613 @classmethod - def _run_single(cls, bias_data, base_fps, bias_feature, base_dset, - base_gid, base_handler, daily_reduction, bias_ti, + def _run_single(cls, + bias_data, + base_fps, + bias_feature, + base_dset, + base_gid, + base_handler, + daily_reduction, + bias_ti, decimals): """Find the nominal scalar + adder combination to bias correct data at a single site""" - base_data, _ = cls.get_base_data(base_fps, base_dset, - base_gid, base_handler, + base_data, _ = cls.get_base_data(base_fps, + base_dset, + base_gid, + base_handler, daily_reduction=daily_reduction, decimals=decimals) - out = cls.get_linear_correction(bias_data, base_data, - bias_feature, base_dset) + out = cls.get_linear_correction(bias_data, base_data, bias_feature, + base_dset) return out - def fill_and_smooth(self, out, fill_extend=True, smooth_extend=0, + def fill_and_smooth(self, + out, + fill_extend=True, + smooth_extend=0, smooth_interior=0): """Fill data extending beyond the base meta data extent by doing a nearest neighbor gap fill. 
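# A small numeric sketch of the scalar + adder factors computed by
# get_linear_correction() above. Applied as "bias * scalar + adder" (the
# intended downstream use of the written factors, stated here as an
# assumption), they match the biased series' mean and standard deviation to
# the baseline series.
import numpy as np

base_data = np.array([2.0, 4.0, 6.0, 8.0])   # baseline series
bias_data = np.array([1.0, 2.0, 3.0, 4.0])   # biased series

scalar = np.nanstd(base_data) / np.nanstd(bias_data)            # 2.0
adder = np.nanmean(base_data) - np.nanmean(bias_data) * scalar  # 0.0

corrected = bias_data * scalar + adder
assert np.isclose(corrected.mean(), base_data.mean())
assert np.isclose(corrected.std(), base_data.std())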
Smooth interior and extended region with @@ -624,8 +657,8 @@ def fill_and_smooth(self, out, fill_extend=True, smooth_extend=0, arr_smooth = arr[..., idt] - needs_fill = ((np.isnan(arr_smooth).any() and fill_extend) - or smooth_interior > 0) + needs_fill = (np.isnan(arr_smooth).any() + and fill_extend) or smooth_interior > 0 if needs_fill: arr_smooth = nn_fill_array(arr_smooth) @@ -677,11 +710,17 @@ def write_outputs(self, fp_out, out): for k, v in self.meta.items(): f.attrs[k] = json.dumps(v) - logger.info('Wrote scalar adder factors to file: {}' - .format(fp_out)) - - def run(self, knn, threshold=0.6, fp_out=None, max_workers=None, - daily_reduction='avg', fill_extend=True, smooth_extend=0, + logger.info( + 'Wrote scalar adder factors to file: {}'.format(fp_out)) + + def run(self, + knn, + threshold=0.6, + fp_out=None, + max_workers=None, + daily_reduction='avg', + fill_extend=True, + smooth_extend=0, smooth_interior=0): """Run linear correction factor calculations for every site in the bias dataset @@ -740,12 +779,17 @@ def run(self, knn, threshold=0.6, fp_out=None, max_workers=None, if np.mean(dist) < threshold: bias_data = self.get_bias_data(bias_gid) - single_out = self._run_single(bias_data, self.base_fps, - self.bias_feature, - self.base_dset, base_gid, - self.base_handler, - daily_reduction, - self.bias_ti, self.decimals) + single_out = self._run_single( + bias_data, + self.base_fps, + self.bias_feature, + self.base_dset, + base_gid, + self.base_handler, + daily_reduction, + self.bias_ti, + self.decimals, + ) for key, arr in single_out.items(): self.out[key][raster_loc] = arr @@ -753,8 +797,9 @@ def run(self, knn, threshold=0.6, fp_out=None, max_workers=None, 'sites'.format(i + 1, len(self.bias_meta))) else: - logger.debug('Running parallel calculation with {} workers.' 
- .format(max_workers)) + logger.debug( + 'Running parallel calculation with {} workers.'.format( + max_workers)) with ProcessPoolExecutor(max_workers=max_workers) as exe: futures = {} for bias_gid, bias_row in self.bias_meta.iterrows(): @@ -765,11 +810,18 @@ def run(self, knn, threshold=0.6, fp_out=None, max_workers=None, if np.mean(dist) < threshold: bias_data = self.get_bias_data(bias_gid) - future = exe.submit(self._run_single, bias_data, - self.base_fps, self.bias_feature, - self.base_dset, base_gid, - self.base_handler, daily_reduction, - self.bias_ti, self.decimals) + future = exe.submit( + self._run_single, + bias_data, + self.base_fps, + self.bias_feature, + self.base_dset, + base_gid, + self.base_handler, + daily_reduction, + self.bias_ti, + self.decimals, + ) futures[future] = raster_loc logger.debug('Finished launching futures.') @@ -802,14 +854,23 @@ class MonthlyLinearCorrection(LinearCorrection): """size of the time dimension, 12 is monthly bias correction""" @classmethod - def _run_single(cls, bias_data, base_fps, bias_feature, base_dset, - base_gid, base_handler, daily_reduction, bias_ti, + def _run_single(cls, + bias_data, + base_fps, + bias_feature, + base_dset, + base_gid, + base_handler, + daily_reduction, + bias_ti, decimals): """Find the nominal scalar + adder combination to bias correct data at a single site""" - base_data, base_ti = cls.get_base_data(base_fps, base_dset, - base_gid, base_handler, + base_data, base_ti = cls.get_base_data(base_fps, + base_dset, + base_gid, + base_handler, daily_reduction=daily_reduction, decimals=decimals) @@ -819,8 +880,7 @@ def _run_single(cls, bias_data, base_fps, bias_feature, base_dset, f'base_{base_dset}_mean': base_arr.copy(), f'base_{base_dset}_std': base_arr.copy(), f'{bias_feature}_scalar': base_arr.copy(), - f'{bias_feature}_adder': base_arr.copy(), - } + f'{bias_feature}_adder': base_arr.copy()} for month in range(1, 13): bias_mask = bias_ti.month == month @@ -900,7 +960,7 @@ def _run_skill_eval(cls, bias_data, base_data, bias_feature, base_dset): out = {} bias_mean = np.nanmean(bias_data) base_mean = np.nanmean(base_data) - out[f'{bias_feature}_bias'] = (bias_mean - base_mean) + out[f'{bias_feature}_bias'] = bias_mean - base_mean out[f'bias_{bias_feature}_mean'] = bias_mean out[f'bias_{bias_feature}_std'] = np.nanstd(bias_data) diff --git a/sup3r/pipeline/__init__.py b/sup3r/pipeline/__init__.py index ab67e730b..459e25cc3 100644 --- a/sup3r/pipeline/__init__.py +++ b/sup3r/pipeline/__init__.py @@ -2,5 +2,6 @@ """ Sup3r data pipeline architecture. 
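# Sketch of the per-month masking in MonthlyLinearCorrection._run_single()
# above: one scalar/adder pair per calendar month fills the 12-entry output
# arrays. The series below are synthetic stand-ins for the bias and base data.
import numpy as np
import pandas as pd

bias_ti = pd.date_range('2015-01-01', '2015-12-31', freq='D')
bias_data = np.random.rand(len(bias_ti))
base_ti, base_data = bias_ti, bias_data * 1.2 + 0.5

scalars = np.full(12, np.nan)
adders = np.full(12, np.nan)
for month in range(1, 13):
    bmask = bias_ti.month == month
    gmask = base_ti.month == month
    if bmask.any() and gmask.any():
        scalars[month - 1] = (np.nanstd(base_data[gmask])
                              / np.nanstd(bias_data[bmask]))
        adders[month - 1] = (np.nanmean(base_data[gmask])
                             - np.nanmean(bias_data[bmask]) * scalars[month - 1])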
""" +from gaps.legacy import Status + from .pipeline import Sup3rPipeline -from reV.pipeline.status import Status diff --git a/sup3r/pipeline/config.py b/sup3r/pipeline/config.py index 86025a267..7fe4345de 100644 --- a/sup3r/pipeline/config.py +++ b/sup3r/pipeline/config.py @@ -6,12 +6,14 @@ @author: bnb32 """ -from reV.config.base_config import BaseConfig as RevBaseConfig +import os +from typing import ClassVar + from reV.config.base_analysis_config import AnalysisConfig -from reV.config.pipeline import PipelineConfig -from reV.utilities.exceptions import ConfigError +from reV.config.base_config import BaseConfig as RevBaseConfig +from reV.utilities.exceptions import ConfigError, PipelineError -from sup3r import SUP3R_DIR, TEST_DATA_DIR, CONFIG_DIR +from sup3r import CONFIG_DIR, SUP3R_DIR, TEST_DATA_DIR class BaseConfig(RevBaseConfig): @@ -20,9 +22,11 @@ class BaseConfig(RevBaseConfig): REQUIREMENTS = () """Required keys for config""" - STR_REP = {'SUP3R_DIR': SUP3R_DIR, - 'CONFIG_DIR': CONFIG_DIR, - 'TEST_DATA_DIR': TEST_DATA_DIR} + STR_REP: ClassVar[dict] = { + 'SUP3R_DIR': SUP3R_DIR, + 'CONFIG_DIR': CONFIG_DIR, + 'TEST_DATA_DIR': TEST_DATA_DIR, + } """Mapping of config inputs (keys) to desired replacements (values) in addition to relative file paths as demarcated by ./ and ../""" @@ -40,13 +44,121 @@ class properties. perform_str_rep : bool Flag to perform string replacement for REVDIR, TESTDATADIR, and ./ """ - super().__init__(config, check_keys=check_keys, - perform_str_rep=perform_str_rep) + super().__init__( + config, check_keys=check_keys, perform_str_rep=perform_str_rep + ) -class Sup3rPipelineConfig(PipelineConfig): +class Sup3rPipelineConfig(AnalysisConfig): """Sup3r pipeline configuration based on reV pipeline""" + def __init__(self, config): + """ + Parameters + ---------- + config : str | dict + File path to config json (str), serialized json object (str), + or dictionary with pre-extracted config. + """ + + super().__init__(config, run_preflight=False) + self._check_pipeline() + self._parse_dirout() + self._check_dirout_status() + + def _check_pipeline(self): + """Check pipeline steps input. ConfigError if bad input.""" + + if 'pipeline' not in self: + raise ConfigError( + 'Could not find required key "pipeline" in the ' + 'pipeline config.' + ) + + if not isinstance(self.pipeline, list): + raise ConfigError( + 'Config arg "pipeline" must be a list of ' + '(command, f_config) pairs, but received "{}".'.format( + type(self.pipeline) + ) + ) + + for di in self.pipeline: + for f_config in di.values(): + if not os.path.exists(f_config): + raise ConfigError( + 'Pipeline step depends on non-existent ' + 'file: {}'.format(f_config) + ) + + def _check_dirout_status(self): + """Check unique status file in dirout.""" + + if os.path.exists(self.dirout): + for fname in os.listdir(self.dirout): + if fname.endswith( + '_status.json' + ) and fname != '{}_status.json'.format(self.name): + msg = ( + 'Cannot run pipeline "{}" in directory ' + '{}. Another pipeline appears to have ' + 'been run here with status json: {}'.format( + self.name, self.dirout, fname + ) + ) + raise PipelineError(msg) + + @property + def pipeline(self): + """Get the pipeline steps. + + Returns + ------- + pipeline : list + reV pipeline run steps. Should be a list of (command, config) + pairs. + """ + + return self['pipeline'] + + @property + def logging(self): + """Get logging kwargs for the pipeline. 
+ + Returns + ------- + dict + """ + return self.get('logging', {"log_file": None, "log_level": "INFO"}) + + @property + def hardware(self): + """Get argument specifying which hardware the pipeline is being run on. + + Defaults to "eagle" (most common use of the reV pipeline) + + Returns + ------- + hardware : str + Name of hardware that this pipeline is being run on. + Defaults to "eagle". + """ + return self.get('hardware', 'eagle') + + @property + def status_file(self): + """Get status file path. + + Returns + ------- + _status_file : str + reV status file path. + """ + if self.dirout is None: + raise ConfigError('Pipeline has not yet been initialized.') + + return os.path.join(self.dirout, '{}_status.json'.format(self.name)) + # pylint: disable=W0201 def _parse_dirout(self): """Parse pipeline steps for common dirout and unique job names.""" @@ -55,21 +167,27 @@ def _parse_dirout(self): names = [] for di in self.pipeline: for f_config in di.values(): - config = AnalysisConfig(f_config, check_keys=False, - run_preflight=False) + config = AnalysisConfig( + f_config, check_keys=False, run_preflight=False + ) dirouts.append(config.dirout) if 'name' in config: names.append(config.name) if len(set(dirouts)) != 1: - raise ConfigError('Pipeline steps must have a common output ' - 'directory but received {} different ' - 'directories.'.format(len(set(dirouts)))) + raise ConfigError( + 'Pipeline steps must have a common output ' + 'directory but received {} different ' + 'directories.'.format(len(set(dirouts))) + ) else: self._dirout = dirouts[0] if len(set(names)) != len(names): - raise ConfigError('Pipeline steps must have a unique job names ' - 'directory but received {} duplicate names.' - .format(len(names) - len(set(names)))) + raise ConfigError( + 'Pipeline steps must have a unique job names ' + 'directory but received {} duplicate names.'.format( + len(names) - len(set(names)) + ) + ) diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index c41c09795..863e1c920 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -24,8 +24,8 @@ OutputHandlerH5, OutputHandlerNC, ) -from sup3r.preprocessing.data_handling import InputMixIn -from sup3r.preprocessing.exogenous_data_handling import ExogenousDataHandler +from sup3r.preprocessing.data_handling import ExogenousDataHandler +from sup3r.preprocessing.data_handling.base import InputMixIn from sup3r.utilities import ModuleName from sup3r.utilities.cli import BaseCLI from sup3r.utilities.execution import DistributedProcess @@ -43,8 +43,17 @@ class ForwardPassSlicer: """Get slices for sending data chunks through model.""" - def __init__(self, coarse_shape, time_steps, temporal_slice, chunk_shape, - s_enhancements, t_enhancements, spatial_pad, temporal_pad): + def __init__( + self, + coarse_shape, + time_steps, + temporal_slice, + chunk_shape, + s_enhancements, + t_enhancements, + spatial_pad, + temporal_pad, + ): """ Parameters ---------- @@ -214,8 +223,12 @@ def t_lr_pad_slices(self): """ if self._t_lr_pad_slices is None: self._t_lr_pad_slices = self.get_padded_slices( - self.t_lr_slices, self.time_steps, 1, - self.temporal_pad, self.temporal_slice.step) + self.t_lr_slices, + self.time_steps, + 1, + self.temporal_pad, + self.temporal_slice.step, + ) return self._t_lr_pad_slices @property @@ -255,8 +268,10 @@ def t_hr_crop_slices(self): # don't use self.get_cropped_slices() here because temporal padding # gets weird at beginning and end of timeseries and the temporal # axis should always be evenly 
chunked. - self._t_hr_crop_slices = [slice(hr_crop_start, hr_crop_stop) - for _ in range(len(self.t_lr_slices))] + self._t_hr_crop_slices = [ + slice(hr_crop_start, hr_crop_stop) + for _ in range(len(self.t_lr_slices)) + ] return self._t_hr_crop_slices @@ -311,8 +326,12 @@ def s_lr_crop_slices(self): self.s2_lr_pad_slices, 1) for i, _ in enumerate(self.s1_lr_slices): for j, _ in enumerate(self.s2_lr_slices): - lr_crop_slice = (s1_crop_slices[i], s2_crop_slices[j], - slice(None), slice(None)) + lr_crop_slice = ( + s1_crop_slices[i], + s2_crop_slices[j], + slice(None), + slice(None), + ) self._s_lr_crop_slices.append(lr_crop_slice) return self._s_lr_crop_slices @@ -335,10 +354,14 @@ def s_hr_crop_slices(self): if self._s_hr_crop_slices is None: self._s_hr_crop_slices = [] - s1_hr_crop_slices = [slice(hr_crop_start, hr_crop_stop) - for _ in range(len(self.s1_lr_slices))] - s2_hr_crop_slices = [slice(hr_crop_start, hr_crop_stop) - for _ in range(len(self.s2_lr_slices))] + s1_hr_crop_slices = [ + slice(hr_crop_start, hr_crop_stop) + for _ in range(len(self.s1_lr_slices)) + ] + s2_hr_crop_slices = [ + slice(hr_crop_start, hr_crop_stop) + for _ in range(len(self.s2_lr_slices)) + ] for _, s1 in enumerate(s1_hr_crop_slices): for _, s2 in enumerate(s2_hr_crop_slices): @@ -375,8 +398,11 @@ def s1_lr_pad_slices(self): spatial dimension""" if self._s1_lr_pad_slices is None: self._s1_lr_pad_slices = self.get_padded_slices( - self.s1_lr_slices, self.grid_shape[0], 1, - padding=self.spatial_pad) + self.s1_lr_slices, + self.grid_shape[0], + 1, + padding=self.spatial_pad, + ) return self._s1_lr_pad_slices @property @@ -385,8 +411,11 @@ def s2_lr_pad_slices(self): spatial dimension""" if self._s2_lr_pad_slices is None: self._s2_lr_pad_slices = self.get_padded_slices( - self.s2_lr_slices, self.grid_shape[1], 1, - padding=self.spatial_pad) + self.s2_lr_slices, + self.grid_shape[1], + 1, + padding=self.spatial_pad, + ) return self._s2_lr_pad_slices @property @@ -394,7 +423,8 @@ def s1_lr_slices(self): """List of low resolution spatial slices for first spatial dimension considering padding on all sides of the spatial raster.""" ind = slice(0, self.grid_shape[0]) - slices = get_chunk_slices(self.grid_shape[0], self.chunk_shape[0], + slices = get_chunk_slices(self.grid_shape[0], + self.chunk_shape[0], index_slice=ind) return slices @@ -403,7 +433,8 @@ def s2_lr_slices(self): """List of low resolution spatial slices for second spatial dimension considering padding on all sides of the spatial raster.""" ind = slice(0, self.grid_shape[1]) - slices = get_chunk_slices(self.grid_shape[1], self.chunk_shape[1], + slices = get_chunk_slices(self.grid_shape[1], + self.chunk_shape[1], index_slice=ind) return slices @@ -415,8 +446,9 @@ def t_lr_slices(self): n_chunks = int(np.ceil(n_chunks)) ti_slices = self.dummy_time_index[self.temporal_slice] ti_slices = np.array_split(ti_slices, n_chunks) - ti_slices = [slice(c[0], c[-1] + 1, self.temporal_slice.step) - for c in ti_slices] + ti_slices = [ + slice(c[0], c[-1] + 1, self.temporal_slice.step) for c in ti_slices + ] return ti_slices @staticmethod @@ -567,18 +599,24 @@ class ForwardPassStrategy(InputMixIn, DistributedProcess): crop generator output to stich the chunks back togerther. 
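# Not the ForwardPassSlicer implementation -- just a standalone sketch of the
# chunk-plus-pad slicing idea above: split one dimension into chunks, widen
# each chunk by the pad width clipped to the domain, run the generator on the
# overlapping padded inputs, then crop outputs back to non-overlapping chunks.
import numpy as np

dim, chunk, pad = 20, 6, 2
starts = np.arange(0, dim, chunk)
lr_slices = [slice(s, min(s + chunk, dim)) for s in starts]
lr_pad_slices = [slice(max(s.start - pad, 0), min(s.stop + pad, dim))
                 for s in lr_slices]
print(lr_slices)      # chunks: (0, 6), (6, 12), (12, 18), (18, 20)
print(lr_pad_slices)  # padded: (0, 8), (4, 14), (10, 20), (16, 20)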
""" - def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, - spatial_pad, temporal_pad, - model_class='Sup3rGan', - out_pattern=None, - input_handler=None, - input_handler_kwargs=None, - incremental=True, - worker_kwargs=None, - exo_kwargs=None, - bias_correct_method=None, - bias_correct_kwargs=None, - max_nodes=None): + def __init__( + self, + file_paths, + model_kwargs, + fwp_chunk_shape, + spatial_pad, + temporal_pad, + model_class='Sup3rGan', + out_pattern=None, + input_handler=None, + input_handler_kwargs=None, + incremental=True, + worker_kwargs=None, + exo_kwargs=None, + bias_correct_method=None, + bias_correct_kwargs=None, + max_nodes=None, + ): """Use these inputs to initialize data handlers on different nodes and to define the size of the data chunks that will be passed through the generator. @@ -683,13 +721,16 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, grid_shape = self._input_handler_kwargs.get('shape', None) raster_file = self._input_handler_kwargs.get('raster_file', None) raster_index = self._input_handler_kwargs.get('raster_index', None) - temporal_slice = self._input_handler_kwargs.get('temporal_slice', - slice(None, None, 1)) - InputMixIn.__init__(self, target=target, - shape=grid_shape, - raster_file=raster_file, - raster_index=raster_index, - temporal_slice=temporal_slice) + temporal_slice = self._input_handler_kwargs.get( + 'temporal_slice', slice(None, None, 1)) + InputMixIn.__init__( + self, + target=target, + shape=grid_shape, + raster_file=raster_file, + raster_index=raster_index, + temporal_slice=temporal_slice, + ) self.file_paths = file_paths self.model_kwargs = model_kwargs @@ -713,8 +754,8 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, self._single_ts_files = self._input_handler_kwargs.get( 'single_ts_files', None) - self.cache_pattern = self._input_handler_kwargs.get('cache_pattern', - None) + self.cache_pattern = self._input_handler_kwargs.get( + 'cache_pattern', None) self.max_workers = self.worker_kwargs.get('max_workers', None) self.output_workers = self.worker_kwargs.get('output_workers', None) self.pass_workers = self.worker_kwargs.get('pass_workers', None) @@ -741,18 +782,23 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, self.t_enhance = np.product(self.t_enhancements) self.output_features = model.output_features - self.fwp_slicer = ForwardPassSlicer(self.grid_shape, - self.raw_tsteps, - self.temporal_slice, - self.fwp_chunk_shape, - self.s_enhancements, - self.t_enhancements, - self.spatial_pad, - self.temporal_pad) - - DistributedProcess.__init__(self, max_nodes=max_nodes, - max_chunks=self.fwp_slicer.n_chunks, - incremental=self.incremental) + self.fwp_slicer = ForwardPassSlicer( + self.grid_shape, + self.raw_tsteps, + self.temporal_slice, + self.fwp_chunk_shape, + self.s_enhancements, + self.t_enhancements, + self.spatial_pad, + self.temporal_pad, + ) + + DistributedProcess.__init__( + self, + max_nodes=max_nodes, + max_chunks=self.fwp_slicer.n_chunks, + incremental=self.incremental, + ) self.preflight() @@ -776,29 +822,41 @@ def preflight(self): f'({self.fwp_chunk_shape[2] + 2 * self.temporal_pad}) ' f'larger than the full temporal domain ({self.raw_tsteps}). ' 'Should just run without temporal chunking. 
') - if (self.fwp_chunk_shape[2] + 2 * self.temporal_pad - >= self.raw_tsteps): + if self.fwp_chunk_shape[2] + 2 * self.temporal_pad >= self.raw_tsteps: logger.warning(msg) warnings.warn(msg) - hr_data_shape = (self.grid_shape[0] * self.s_enhance, - self.grid_shape[1] * self.s_enhance) + hr_data_shape = ( + self.grid_shape[0] * self.s_enhance, + self.grid_shape[1] * self.s_enhance, + ) self.gids = np.arange(np.product(hr_data_shape)) self.gids = self.gids.reshape(hr_data_shape) out = self.fwp_slicer.get_spatial_slices() self.lr_slices, self.lr_pad_slices, self.hr_slices = out + def _get_spatial_chunk_index(self, chunk_index): + """Get the spatial index for the given chunk index""" + return chunk_index % self.fwp_slicer.n_spatial_chunks + + def _get_temporal_chunk_index(self, chunk_index): + """Get the temporal index for the given chunk index""" + return chunk_index // self.fwp_slicer.n_spatial_chunks + # pylint: disable=E1102 @property def init_handler(self): """Get initial input handler used for extracting handler features and low res grid""" if self._init_handler is None: - out = self.input_handler_class(self.file_paths[0], [], - target=self.target, - shape=self.grid_shape, - worker_kwargs=dict(ti_workers=1)) + out = self.input_handler_class( + self.file_paths[0], + [], + target=self.target, + shape=self.grid_shape, + worker_kwargs=dict(ti_workers=1), + ) self._init_handler = out return self._init_handler @@ -828,8 +886,8 @@ def hr_lat_lon(self): if self._hr_lat_lon is None: logger.info('Getting high-resolution grid for full output domain.') lr_lat_lon = self.lr_lat_lon.copy() - self._hr_lat_lon = OutputHandler.get_lat_lon(lr_lat_lon, - self.gids.shape) + self._hr_lat_lon = OutputHandler.get_lat_lon( + lr_lat_lon, self.gids.shape) return self._hr_lat_lon def get_full_domain(self, file_paths): @@ -838,7 +896,8 @@ def get_full_domain(self, file_paths): def get_lat_lon(self, file_paths, raster_index, invert_lat=False): """Get lat/lon grid for requested target and shape""" - return self.input_handler_class.get_lat_lon(file_paths, raster_index, + return self.input_handler_class.get_lat_lon(file_paths, + raster_index, invert_lat=invert_lat) def get_time_index(self, file_paths, max_workers=None, **kwargs): @@ -940,8 +999,8 @@ def max_nodes(self): """Get the maximum number of nodes that this strategy should distribute work to, equal to either the specified max number of nodes or total number of temporal chunks""" - self._max_nodes = (self._max_nodes if self._max_nodes is not None - else self.fwp_slicer.n_temporal_chunks) + self._max_nodes = (self._max_nodes if self._max_nodes is not None else + self.fwp_slicer.n_temporal_chunks) return self._max_nodes @staticmethod @@ -1050,10 +1109,12 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.features = [f for f in self.features if f not in exo_features] self.exogenous_handler = ExogenousDataHandler(**exo_kwargs) self.exogenous_data = self.exogenous_handler.data - shapes = [None if d is None else d.shape - for d in self.exogenous_data] - logger.info('Got exogenous_data of length {} with shapes: {}' - .format(len(self.exogenous_data), shapes)) + shapes = [ + None if d is None else d.shape for d in self.exogenous_data + ] + logger.info( + 'Got exogenous_data of length {} with shapes: {}'.format( + len(self.exogenous_data), shapes)) self.input_handler_class = strategy.input_handler_class @@ -1106,7 +1167,8 @@ def update_input_handler_kwargs(self, strategy): cache_pattern=self.cache_pattern, single_ts_files=self.single_ts_files, 
handle_features=strategy.handle_features, - val_split=0.0) + val_split=0.0, + ) input_handler_kwargs.update(fwp_input_handler_kwargs) return input_handler_kwargs @@ -1141,8 +1203,7 @@ def lr_times(self): @property def lr_lat_lon(self): """Get low resolution lat lon for current chunk""" - return self.strategy.lr_lat_lon[self.lr_slice[0], - self.lr_slice[1]] + return self.strategy.lr_lat_lon[self.lr_slice[0], self.lr_slice[1]] @property def hr_lat_lon(self): @@ -1155,19 +1216,34 @@ def hr_times(self): return self.output_handler_class.get_times( self.lr_times, self.t_enhance * len(self.lr_times)) + @property + def chunk_specific_meta(self): + """Meta with chunk specific info. To be included in chunk output file + global attributes.""" + meta_data = { + "node_index": self.node_index, + 'creation_date': dt.now().strftime("%d/%m/%Y %H:%M:%S"), + 'fwp_chunk_shape': self.strategy.fwp_chunk_shape, + 'spatial_pad': self.strategy.spatial_pad, + 'temporal_pad': self.strategy.temporal_pad, + } + return meta_data + @property def meta(self): """Meta data dictionary for the forward pass run (to write to output files).""" - meta_data = {'gan_meta': self.model.meta, - 'model_kwargs': self.model_kwargs, - 'model_class': self.model_class, - 'spatial_enhance': int(self.s_enhance), - 'temporal_enhance': int(self.t_enhance), - 'input_files': self.file_paths, - 'input_features': self.features, - 'output_features': self.output_features, - } + meta_data = { + 'chunk_meta': self.chunk_specific_meta, + 'gan_meta': self.model.meta, + 'model_kwargs': self.model_kwargs, + 'model_class': self.model_class, + 'spatial_enhance': int(self.s_enhance), + 'temporal_enhance': int(self.t_enhance), + 'input_files': self.file_paths, + 'input_features': self.features, + 'output_features': self.output_features, + } return meta_data @property @@ -1227,12 +1303,12 @@ def chunks(self): @property def spatial_chunk_index(self): """Spatial index for the current chunk going through forward pass""" - return self.chunk_index % self.strategy.fwp_slicer.n_spatial_chunks + return self.strategy._get_spatial_chunk_index(self.chunk_index) @property def temporal_chunk_index(self): """Temporal index for the current chunk going through forward pass""" - return self.chunk_index // self.strategy.fwp_slicer.n_spatial_chunks + return self.strategy._get_temporal_chunk_index(self.chunk_index) @property def out_file(self): @@ -1280,9 +1356,11 @@ def lr_crop_slice(self): @property def chunk_shape(self): """Get shape for the current padded spatiotemporal chunk""" - return (self.lr_pad_slice[0].stop - self.lr_pad_slice[0].start, - self.lr_pad_slice[1].stop - self.lr_pad_slice[1].start, - self.ti_pad_slice.stop - self.ti_pad_slice.start) + return ( + self.lr_pad_slice[0].stop - self.lr_pad_slice[0].start, + self.lr_pad_slice[1].stop - self.lr_pad_slice[1].start, + self.ti_pad_slice.stop - self.ti_pad_slice.start, + ) @property def cache_pattern(self): @@ -1309,8 +1387,8 @@ def raster_file(self): if '{spatial_chunk_index}' not in raster_file: raster_file = raster_file.replace( '.txt', '_{spatial_chunk_index}.txt') - raster_file = raster_file.replace( - '{spatial_chunk_index}', str(self.spatial_chunk_index)) + raster_file = raster_file.replace('{spatial_chunk_index}', + str(self.spatial_chunk_index)) return raster_file @property @@ -1326,32 +1404,36 @@ def pad_width(self): """ ti_start = self.ti_slice.start or 0 ti_stop = self.ti_slice.stop or self.strategy.raw_tsteps - pad_t_start = int(np.maximum(0, (self.strategy.temporal_pad - - ti_start))) - pad_t_end = 
int(np.maximum(0, (self.strategy.temporal_pad - + ti_stop - self.strategy.raw_tsteps))) + pad_t_start = int( + np.maximum(0, (self.strategy.temporal_pad - ti_start))) + pad_t_end = (self.strategy.temporal_pad + ti_stop + - self.strategy.raw_tsteps) + pad_t_end = int(np.maximum(0, pad_t_end)) s1_start = self.lr_slice[0].start or 0 s1_stop = self.lr_slice[0].stop or self.strategy.grid_shape[0] - pad_s1_start = int(np.maximum(0, (self.strategy.spatial_pad - - s1_start))) - pad_s1_end = int(np.maximum(0, (self.strategy.spatial_pad - + s1_stop - - self.strategy.grid_shape[0]))) + pad_s1_start = int( + np.maximum(0, (self.strategy.spatial_pad - s1_start))) + pad_s1_end = (self.strategy.spatial_pad + s1_stop + - self.strategy.grid_shape[0]) + pad_s1_end = int(np.maximum(0, pad_s1_end)) s2_start = self.lr_slice[1].start or 0 s2_stop = self.lr_slice[1].stop or self.strategy.grid_shape[1] - pad_s2_start = int(np.maximum(0, (self.strategy.spatial_pad - - s2_start))) - pad_s2_end = int(np.maximum(0, (self.strategy.spatial_pad - + s2_stop - - self.strategy.grid_shape[1]))) + pad_s2_start = int( + np.maximum(0, (self.strategy.spatial_pad - s2_start))) + pad_s2_end = (self.strategy.spatial_pad + s2_stop + - self.strategy.grid_shape[1]) + pad_s2_end = int(np.maximum(0, pad_s2_end)) return ((pad_s1_start, pad_s1_end), (pad_s2_start, pad_s2_end), (pad_t_start, pad_t_end)) @staticmethod - def pad_source_data(input_data, pad_width, exo_data, - exo_s_enhancements, mode='reflect'): + def pad_source_data(input_data, + pad_width, + exo_data, + exo_s_enhancements, + mode='reflect'): """Pad the edges of the source data from the data handler. Parameters @@ -1390,21 +1472,23 @@ def pad_source_data(input_data, pad_width, exo_data, out = np.pad(input_data, (*pad_width, (0, 0)), mode=mode) logger.info('Padded input data shape from {} to {} using mode "{}" ' - 'with padding argument: {}' - .format(input_data.shape, out.shape, mode, pad_width)) + 'with padding argument: {}'.format(input_data.shape, + out.shape, mode, + pad_width)) if exo_data is not None: for i, i_exo_data in enumerate(exo_data): if i_exo_data is not None: total_s_enhance = exo_s_enhancements[:i + 1] - total_s_enhance = [s for s in total_s_enhance - if s is not None] + total_s_enhance = [ + s for s in total_s_enhance if s is not None + ] total_s_enhance = np.product(total_s_enhance) exo_pad_width = ((total_s_enhance * pad_width[0][0], total_s_enhance * pad_width[0][1]), (total_s_enhance * pad_width[1][0], - total_s_enhance * pad_width[1][1]), - (0, 0)) + total_s_enhance * pad_width[1][1]), (0, + 0)) exo_data[i] = np.pad(i_exo_data, exo_pad_width, mode=mode) return out, exo_data @@ -1443,8 +1527,8 @@ def bias_correct_source_data(self, data, lat_lon): feature_kwargs['time_index'] = self.data_handler.time_index logger.debug('Bias correcting feature "{}" at axis index {} ' - 'using function: {} with kwargs: {}' - .format(feature, idf, method, feature_kwargs)) + 'using function: {} with kwargs: {}'.format( + feature, idf, method, feature_kwargs)) data[..., idf] = method(data[..., idf], lat_lon, **feature_kwargs) @@ -1474,12 +1558,16 @@ def _prep_exogenous_input(self, chunk_shape): arr = np.expand_dims(arr, axis=2) arr = np.repeat(arr, chunk_shape[2], axis=2) - target_shape = (arr.shape[0], arr.shape[1], chunk_shape[2], - arr.shape[-1]) + target_shape = ( + arr.shape[0], + arr.shape[1], + chunk_shape[2], + arr.shape[-1], + ) msg = ('Target shape for exogenous data in forward pass ' 'chunk was {}, but something went wrong and i ' - 'resized original data shape 
from {} to {}' - .format(target_shape, og_shape, arr.shape)) + 'resized original data shape from {} to {}'.format( + target_shape, og_shape, arr.shape)) assert arr.shape == target_shape, msg exo_data.append(arr) @@ -1487,10 +1575,17 @@ def _prep_exogenous_input(self, chunk_shape): return exo_data @classmethod - def _run_generator(cls, data_chunk, hr_crop_slices, - model=None, model_kwargs=None, model_class=None, - s_enhance=None, t_enhance=None, - exo_data=None): + def _run_generator( + cls, + data_chunk, + hr_crop_slices, + model=None, + model_kwargs=None, + model_class=None, + s_enhance=None, + t_enhance=None, + exo_data=None, + ): """Run forward pass of the generator on smallest data chunk. Each chunk has a maximum shape given by self.strategy.fwp_chunk_shape. @@ -1547,8 +1642,8 @@ def _run_generator(cls, data_chunk, hr_crop_slices, try: hi_res = model.generate(data_chunk, exogenous_data=exo_data) except Exception as e: - msg = ('Forward pass failed on chunk with shape {}.' - .format(data_chunk.shape)) + msg = 'Forward pass failed on chunk with shape {}.'.format( + data_chunk.shape) logger.exception(msg) raise RuntimeError(msg) from e @@ -1558,16 +1653,16 @@ def _run_generator(cls, data_chunk, hr_crop_slices, if (s_enhance is not None and hi_res.shape[1] != s_enhance * data_chunk.shape[i_lr_s]): msg = ('The stated spatial enhancement of {}x did not match ' - 'the low res / high res shapes of {} -> {}' - .format(s_enhance, data_chunk.shape, hi_res.shape)) + 'the low res / high res shapes of {} -> {}'.format( + s_enhance, data_chunk.shape, hi_res.shape)) logger.error(msg) raise RuntimeError(msg) if (t_enhance is not None and hi_res.shape[3] != t_enhance * data_chunk.shape[i_lr_t]): msg = ('The stated temporal enhancement of {}x did not match ' - 'the low res / high res shapes of {} -> {}' - .format(t_enhance, data_chunk.shape, hi_res.shape)) + 'the low res / high res shapes of {} -> {}'.format( + t_enhance, data_chunk.shape, hi_res.shape)) logger.error(msg) raise RuntimeError(msg) @@ -1647,7 +1742,7 @@ def get_node_cmd(cls, config): import_str += 'import os;\n' import_str += 'os.environ["CUDA_VISIBLE_DEVICES"] = "-1";\n' import_str += 'import time;\n' - import_str += 'from reV.pipeline.status import Status;\n' + import_str += 'from sup3r.pipeline import Status;\n' import_str += 'from rex import init_logger;\n' import_str += ('from sup3r.pipeline.forward_pass ' f'import ForwardPassStrategy, {cls.__name__};\n') @@ -1657,7 +1752,7 @@ def get_node_cmd(cls, config): node_index = config['node_index'] log_file = config.get('log_file', None) log_level = config.get('log_level', 'INFO') - log_arg_str = (f'"sup3r", log_level="{log_level}"') + log_arg_str = f'"sup3r", log_level="{log_level}"' if log_file is not None: log_arg_str += f', log_file="{log_file}"' @@ -1669,7 +1764,7 @@ def get_node_cmd(cls, config): "t_elap = time.time() - t0;\n") cmd = BaseCLI.add_status_cmd(config, ModuleName.FORWARD_PASS, cmd) - cmd += (";\'\n") + cmd += ";\'\n" return cmd.replace('\\', '/') @@ -1683,7 +1778,7 @@ def _constant_output_check(self, out_data): Forward pass output corresponding to the given chunk index """ for i, f in enumerate(self.output_features): - msg = (f'All spatiotemporal values are the same for {f} output!') + msg = f'All spatiotemporal values are the same for {f} output!' 
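# Worked numbers for the pad_width property above (temporal axis shown; both
# spatial axes follow the same max(0, ...) pattern). A chunk at the end of the
# timeseries has no real neighbors to its right, so the leftover temporal_pad
# is supplied as reflected padding by pad_source_data()/np.pad.
import numpy as np

temporal_pad, raw_tsteps = 3, 100
ti_slice = slice(90, 100)   # last temporal chunk

pad_t_start = int(np.maximum(0, temporal_pad - ti_slice.start))            # 0
pad_t_end = int(np.maximum(0, temporal_pad + ti_slice.stop - raw_tsteps))  # 3

data = np.random.rand(4, 4, 10, 2)   # (spatial_1, spatial_2, time, features)
padded = np.pad(data, ((0, 0), (0, 0), (pad_t_start, pad_t_end), (0, 0)),
                mode='reflect')
print(padded.shape)                  # (4, 4, 13, 2)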
if np.all(out_data[0, 0, 0, i] == out_data[..., i]): self.strategy.failed_chunks = True logger.error(msg) @@ -1770,8 +1865,11 @@ def _run_serial(cls, strategy, node_index): 'serial.') for i, chunk_index in enumerate(strategy.node_chunks[node_index]): now = dt.now() - cls._single_proc_run(strategy=strategy, node_index=node_index, - chunk_index=chunk_index) + cls._single_proc_run( + strategy=strategy, + node_index=node_index, + chunk_index=chunk_index, + ) mem = psutil.virtual_memory() logger.info('Finished forward pass on chunk_index=' f'{chunk_index} in {dt.now() - now}. {i + 1} of ' @@ -1808,12 +1906,16 @@ def _run_parallel(cls, strategy, node_index): with SpawnProcessPool(**pool_kws) as exe: now = dt.now() for _i, chunk_index in enumerate(strategy.node_chunks[node_index]): - fut = exe.submit(cls._single_proc_run, - strategy=strategy, - node_index=node_index, - chunk_index=chunk_index) - futures[fut] = {'chunk_index': chunk_index, - 'start_time': dt.now()} + fut = exe.submit( + cls._single_proc_run, + strategy=strategy, + node_index=node_index, + chunk_index=chunk_index, + ) + futures[fut] = { + 'chunk_index': chunk_index, + 'start_time': dt.now(), + } logger.info(f'Started {len(futures)} forward pass runs in ' f'{dt.now() - now}.') @@ -1855,18 +1957,28 @@ def run_chunk(self): exo_data = self._prep_exogenous_input(data_chunk.shape) self.output_data = self._run_generator( - data_chunk, hr_crop_slices=self.hr_crop_slice, model=self.model, - model_kwargs=self.model_kwargs, model_class=self.model_class, - s_enhance=self.s_enhance, t_enhance=self.t_enhance, - exo_data=exo_data) + data_chunk, + hr_crop_slices=self.hr_crop_slice, + model=self.model, + model_kwargs=self.model_kwargs, + model_class=self.model_class, + s_enhance=self.s_enhance, + t_enhance=self.t_enhance, + exo_data=exo_data, + ) self._constant_output_check(self.output_data) if self.out_file is not None: logger.info(f'Saving forward pass output to {self.out_file}.') self.output_handler_class._write_output( - data=self.output_data, features=self.model.output_features, - lat_lon=self.hr_lat_lon, times=self.hr_times, - out_file=self.out_file, meta_data=self.meta, - max_workers=self.output_workers, gids=self.gids) + data=self.output_data, + features=self.model.output_features, + lat_lon=self.hr_lat_lon, + times=self.hr_times, + out_file=self.out_file, + meta_data=self.meta, + max_workers=self.output_workers, + gids=self.gids, + ) return self.output_data diff --git a/sup3r/pipeline/pipeline.py b/sup3r/pipeline/pipeline.py index b84890ae6..8ebed5864 100644 --- a/sup3r/pipeline/pipeline.py +++ b/sup3r/pipeline/pipeline.py @@ -1,7 +1,8 @@ """Sup3r data pipeline architecture.""" import logging +from typing import ClassVar -from reV.pipeline.pipeline import Pipeline +from gaps.legacy import Pipeline from rex.utilities.loggers import init_logger from sup3r.pipeline.config import Sup3rPipelineConfig @@ -15,10 +16,12 @@ class Sup3rPipeline(Pipeline): CMD_BASE = 'python -m sup3r.cli -c {fp_config} {command}' COMMANDS = ModuleName.all_names() - RETURN_CODES = {0: 'successful', - 1: 'running', - 2: 'failed', - 3: 'complete'} + RETURN_CODES: ClassVar[dict] = { + 0: 'successful', + 1: 'running', + 2: 'failed', + 3: 'complete', + } def __init__(self, pipeline, monitor=True, verbose=False): """Parameters @@ -39,4 +42,4 @@ def __init__(self, pipeline, monitor=True, verbose=False): # init logger for pipeline module if requested in input config if 'logging' in self._config: init_logger('sup3r.pipeline', **self._config.logging) - 
init_logger('reV.pipeline', **self._config.logging) + init_logger('gaps.legacy', **self._config.logging) diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index 4f27ee1dd..a4e9a1fe3 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -57,7 +57,7 @@ def get_node_cmd(cls, config): 'import Collector;\n' 'from rex import init_logger;\n' 'import time;\n' - 'from reV.pipeline.status import Status;\n' + 'from sup3r.pipeline import Status;\n' ) dc_fun_str = get_fun_call_str(cls.collect, config) @@ -132,7 +132,6 @@ def get_slices( raise RuntimeError(msg) row_slice = slice(np.min(row_loc), np.max(row_loc) + 1) - col_slice = slice(np.min(col_loc), np.max(col_loc) + 1) msg = ( f'row_slice={row_slice} conflict with row_indices={row_loc}. ' @@ -140,15 +139,6 @@ def get_slices( ) assert (row_slice.stop - row_slice.start) == len(row_loc), msg - msg = ( - f'col_slice={col_slice} conflict with col_indices={col_loc}. ' - 'Indices do not seem to be increasing and/or contiguous.' - ) - check = (col_slice.stop - col_slice.start) == len(col_loc) - if not check: - logger.warning(msg) - warn(msg) - return row_slice, col_loc def get_coordinate_indices(self, target_meta, full_meta, threshold=1e-4): diff --git a/sup3r/postprocessing/file_handling.py b/sup3r/postprocessing/file_handling.py index 79f4bc385..f86339f16 100644 --- a/sup3r/postprocessing/file_handling.py +++ b/sup3r/postprocessing/file_handling.py @@ -2,28 +2,30 @@ author : @bbenton """ +import json +import logging +import os +import re from abc import abstractmethod from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime as dt +from warnings import warn + import numpy as np -import xarray as xr import pandas as pd -import logging +import xarray as xr +from rex.outputs import Outputs as BaseRexOutputs from scipy.interpolate import griddata -import re -from datetime import datetime as dt -import json -import os -from warnings import warn -from sup3r.version import __version__ -from sup3r.utilities import VERSION_RECORD -from sup3r.utilities.utilities import (invert_uv, - get_time_dim_name, - estimate_max_workers, - pd_date_range) from sup3r.preprocessing.feature_handling import Feature - -from rex.outputs import Outputs as BaseRexOutputs +from sup3r.utilities import VERSION_RECORD +from sup3r.utilities.utilities import ( + estimate_max_workers, + get_time_dim_name, + invert_uv, + pd_date_range, +) +from sup3r.version import __version__ logger = logging.getLogger(__name__) @@ -235,6 +237,8 @@ def write_data(cls, out_file, dsets, time_index, data_list, meta, Pre-existing H5 file output path dsets : list list of datasets to write to out_file + time_index : pd.DatetimeIndex() + Pandas datetime index to use for file time_index. 
data_list : list List of np.ndarray objects to write to out_file meta : pd.DataFrame @@ -260,7 +264,7 @@ def write_data(cls, out_file, dsets, time_index, data_list, meta, os.replace(tmp_file, out_file) msg = ('Saved output of size ' - f'{(len(data_list),) + data_list[0].shape} to: {out_file}') + f'{(len(data_list), *data_list[0].shape)} to: {out_file}') logger.info(msg) diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py index 5ed05fc96..c89196a62 100644 --- a/sup3r/preprocessing/batch_handling.py +++ b/sup3r/preprocessing/batch_handling.py @@ -12,7 +12,9 @@ from rex.utilities import log_mem from scipy.ndimage.filters import gaussian_filter -from sup3r.preprocessing.data_handling import DataHandlerDCforH5 +from sup3r.preprocessing.data_handling.h5_data_handling import ( + DataHandlerDCforH5, +) from sup3r.utilities.utilities import ( estimate_max_workers, nn_fill_array, @@ -221,13 +223,13 @@ def __init__( handler_shapes = np.array([d.sample_shape for d in data_handlers]) assert np.all(handler_shapes[0] == handler_shapes) - self.handlers = data_handlers + self.s_enhance = s_enhance + self.t_enhance = t_enhance + self.data_handlers = data_handlers self.batch_size = batch_size self.sample_shape = handler_shapes[0] self.val_indices = self._get_val_indices() self.max = np.ceil(len(self.val_indices) / (batch_size)) - self.s_enhance = s_enhance - self.t_enhance = t_enhance self._remaining_observations = len(self.val_indices) self.temporal_coarsening_method = temporal_coarsening_method self._i = 0 @@ -235,6 +237,7 @@ def __init__( self.output_features = output_features self.smoothing = smoothing self.smoothing_ignore = smoothing_ignore + self.current_batch_indices = [] def _get_val_indices(self): """List of dicts to index each validation data observation across all @@ -249,7 +252,7 @@ def _get_val_indices(self): """ val_indices = [] - for i, h in enumerate(self.handlers): + for i, h in enumerate(self.data_handlers): if h.val_data is not None: for _ in range(h.val_data.shape[2]): spatial_slice = uniform_box_sampler( @@ -286,13 +289,13 @@ def shape(self): dimension """ time_steps = 0 - for h in self.handlers: + for h in self.data_handlers: time_steps += h.val_data.shape[2] return ( - self.handlers[0].val_data.shape[0], - self.handlers[0].val_data.shape[1], + self.data_handlers[0].val_data.shape[0], + self.data_handlers[0].val_data.shape[1], time_steps, - self.handlers[0].val_data.shape[3], + self.data_handlers[0].val_data.shape[3], ) def __iter__(self): @@ -343,35 +346,30 @@ def __next__(self): validation data batch with low and high res data each with n_observations = batch_size """ + self.current_batch_indices = [] if self._remaining_observations > 0: if self._remaining_observations > self.batch_size: - high_res = np.zeros( - ( - self.batch_size, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.handlers[0].shape[-1], - ), - dtype=np.float32, - ) + n_obs = self.batch_size else: - high_res = np.zeros( - ( - self._remaining_observations, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.handlers[0].shape[-1], - ), - dtype=np.float32, - ) + n_obs = self._remaining_observations + + high_res = np.zeros( + ( + n_obs, + self.sample_shape[0], + self.sample_shape[1], + self.sample_shape[2], + self.data_handlers[0].shape[-1], + ), + dtype=np.float32, + ) for i in range(high_res.shape[0]): val_index = self.val_indices[self._i + i] - high_res[i, ...] = self.handlers[ + high_res[i, ...] 
= self.data_handlers[ val_index['handler_index'] ].val_data[val_index['tuple_index']] self._remaining_observations -= 1 + self.current_batch_indices.append(val_index['handler_index']) if self.sample_shape[2] == 1: high_res = high_res[..., 0, :] @@ -663,7 +661,8 @@ def parallel_load(self): max_workers = self.load_workers if max_workers == 1: for d in self.data_handlers: - d.load_cached_data() + if d.data is None: + d.load_cached_data() else: with ThreadPoolExecutor(max_workers=max_workers) as exe: futures = {} @@ -1296,8 +1295,8 @@ def _get_val_indices(self): val_indices = {} for t in range(self.N_TIME_BINS): val_indices[t] = [] - h_idx = np.random.choice(np.arange(len(self.handlers))) - h = self.handlers[h_idx] + h_idx = np.random.choice(np.arange(len(self.data_handlers))) + h = self.data_handlers[h_idx] for _ in range(self.batch_size): spatial_slice = uniform_box_sampler( h.data, self.sample_shape[:2] @@ -1319,8 +1318,8 @@ def _get_val_indices(self): ) for s in range(self.N_SPACE_BINS): val_indices[s + self.N_TIME_BINS] = [] - h_idx = np.random.choice(np.arange(len(self.handlers))) - h = self.handlers[h_idx] + h_idx = np.random.choice(np.arange(len(self.data_handlers))) + h = self.data_handlers[h_idx] for _ in range(self.batch_size): weights = np.zeros(self.N_SPACE_BINS) weights[s] = 1 @@ -1350,15 +1349,15 @@ def __next__(self): self.sample_shape[0], self.sample_shape[1], self.sample_shape[2], - self.handlers[0].shape[-1], + self.data_handlers[0].shape[-1], ), dtype=np.float32, ) val_indices = self.val_indices[self._i] for i, idx in enumerate(val_indices): - high_res[i, ...] = self.handlers[idx['handler_index']].data[ - idx['tuple_index'] - ] + high_res[i, ...] = self.data_handlers[ + idx['handler_index'] + ].data[idx['tuple_index']] batch = self.BATCH_CLASS.get_coarse_batch( high_res, @@ -1394,15 +1393,15 @@ def __next__(self): self.batch_size, self.sample_shape[0], self.sample_shape[1], - self.handlers[0].shape[-1], + self.data_handlers[0].shape[-1], ), dtype=np.float32, ) val_indices = self.val_indices[self._i] for i, idx in enumerate(val_indices): - high_res[i, ...] = self.handlers[idx['handler_index']].data[ - idx['tuple_index'] - ][..., 0, :] + high_res[i, ...] = self.data_handlers[ + idx['handler_index'] + ].data[idx['tuple_index']][..., 0, :] batch = self.BATCH_CLASS.get_coarse_batch( high_res, diff --git a/sup3r/preprocessing/conditional_moment_batch_handling.py b/sup3r/preprocessing/conditional_moment_batch_handling.py index cdabf8780..a1418b101 100644 --- a/sup3r/preprocessing/conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/conditional_moment_batch_handling.py @@ -3,19 +3,23 @@ Sup3r conditional moment batch_handling module. """ import logging -import numpy as np from datetime import datetime as dt +import numpy as np from rex.utilities import log_mem -from sup3r.utilities.utilities import (spatial_coarsening, - temporal_coarsening, - spatial_simple_enhancing, - temporal_simple_enhancing, - smooth_data) -from sup3r.preprocessing.batch_handling import (Batch, - ValidationData, - BatchHandler) +from sup3r.preprocessing.batch_handling import ( + Batch, + BatchHandler, + ValidationData, +) +from sup3r.utilities.utilities import ( + smooth_data, + spatial_coarsening, + spatial_simple_enhancing, + temporal_coarsening, + temporal_simple_enhancing, +) np.random.seed(42) @@ -55,11 +59,11 @@ def __init__(self, low_res, high_res, output, mask): @property def output(self): """Get the output for the batch. 
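# Summary sketch of the conditional-moment training targets produced by the
# make_output() methods of the BatchMom* classes in this module. Toy arrays
# only: np.repeat stands in for the spatial/temporal simple-enhancing
# upsamplers and mom1_pred for model_mom1._tf_generate(low_res); the BatchMom1
# target being the high-res field itself is an assumption.
import numpy as np

s_enhance = 2
high_res = np.random.rand(1, 4, 4, 6, 2)          # (obs, s1, s2, t, features)
low_res = high_res[:, ::s_enhance, ::s_enhance]   # crude coarsening stand-in
enhanced_lr = np.repeat(np.repeat(low_res, s_enhance, axis=1), s_enhance, axis=2)
mom1_pred = enhanced_lr                           # stand-in for G1(low_res)

out_mom1 = high_res                               # BatchMom1 (assumed): HR
out_mom1_sf = high_res - enhanced_lr              # BatchMom1SF: SF = HR - LR
out_mom2 = (high_res - mom1_pred) ** 2            # BatchMom2: (HR - G1)^2
out_mom2_sf = (high_res - enhanced_lr - mom1_pred) ** 2   # BatchMom2SF
out_mom2_sep_sf = (high_res - enhanced_lr) ** 2   # BatchMom2SepSF: (HR - LR)^2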
- Output predicted by the neural net can be different - than the high_res when doing moment estimation. - For ex: output may be (high_res)**2 - We distinguish output from high_res since it may not be - possible to recover high_res from output.""" + Output predicted by the neural net can be different + than the high_res when doing moment estimation. + For ex: output may be (high_res)**2 + We distinguish output from high_res since it may not be + possible to recover high_res from output.""" return self._output @property @@ -69,10 +73,15 @@ def mask(self): # pylint: disable=W0613 @staticmethod - def make_output(low_res, high_res, - s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None, - t_enhance_mode='constant'): + def make_output( + low_res, + high_res, + s_enhance=None, + t_enhance=None, + model_mom1=None, + output_features_ind=None, + t_enhance_mode='constant', + ): """Make custom batch output Parameters @@ -110,9 +119,13 @@ def make_output(low_res, high_res, # pylint: disable=E1130 @staticmethod - def make_mask(high_res, - s_padding=None, t_padding=None, - end_t_padding=False, t_enhance=None): + def make_mask( + high_res, + s_padding=None, + t_padding=None, + end_t_padding=False, + t_enhance=None, + ): """Make mask for output. The mask is used to ensure consistency when training conditional moments. @@ -183,19 +196,23 @@ def make_mask(high_res, # pylint: disable=W0613 @classmethod - def get_coarse_batch(cls, high_res, - s_enhance, t_enhance=1, - temporal_coarsening_method='subsample', - temporal_enhancing_method='constant', - output_features_ind=None, - output_features=None, - training_features=None, - smoothing=None, - smoothing_ignore=None, - model_mom1=None, - s_padding=None, - t_padding=None, - end_t_padding=False): + def get_coarse_batch( + cls, + high_res, + s_enhance, + t_enhance=1, + temporal_coarsening_method='subsample', + temporal_enhancing_method='constant', + output_features_ind=None, + output_features=None, + training_features=None, + smoothing=None, + smoothing_ignore=None, + model_mom1=None, + s_padding=None, + t_padding=None, + end_t_padding=False, + ): """Coarsen high res data and return Batch with high res and low res data @@ -267,18 +284,26 @@ def get_coarse_batch(cls, high_res, smoothing_ignore = [] if t_enhance != 1: - low_res = temporal_coarsening(low_res, t_enhance, - temporal_coarsening_method) + low_res = temporal_coarsening( + low_res, t_enhance, temporal_coarsening_method + ) - low_res = smooth_data(low_res, training_features, smoothing_ignore, - smoothing) + low_res = smooth_data( + low_res, training_features, smoothing_ignore, smoothing + ) high_res = cls.reduce_features(high_res, output_features_ind) - output = cls.make_output(low_res, high_res, - s_enhance, t_enhance, - model_mom1, output_features_ind, - temporal_enhancing_method) - mask = cls.make_mask(high_res, - s_padding, t_padding, end_t_padding, t_enhance) + output = cls.make_output( + low_res, + high_res, + s_enhance, + t_enhance, + model_mom1, + output_features_ind, + temporal_enhancing_method, + ) + mask = cls.make_mask( + high_res, s_padding, t_padding, end_t_padding, t_enhance + ) batch = cls(low_res, high_res, output, mask) return batch @@ -289,10 +314,15 @@ class BatchMom1SF(BatchMom1): of subfilter vel""" @staticmethod - def make_output(low_res, high_res, - s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None, - t_enhance_mode='constant'): + def make_output( + low_res, + high_res, + s_enhance=None, + t_enhance=None, + model_mom1=None, + 
output_features_ind=None, + t_enhance_mode='constant', + ): """Make custom batch output Parameters @@ -328,11 +358,10 @@ def make_output(low_res, high_res, SF = HR - LR """ # Remove LR from HR - enhanced_lr = spatial_simple_enhancing(low_res, - s_enhance=s_enhance) - enhanced_lr = temporal_simple_enhancing(enhanced_lr, - t_enhance=t_enhance, - mode=t_enhance_mode) + enhanced_lr = spatial_simple_enhancing(low_res, s_enhance=s_enhance) + enhanced_lr = temporal_simple_enhancing( + enhanced_lr, t_enhance=t_enhance, mode=t_enhance_mode + ) enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind) return high_res - enhanced_lr @@ -343,10 +372,15 @@ class BatchMom2(BatchMom1): moment""" @staticmethod - def make_output(low_res, high_res, - s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None, - t_enhance_mode='constant'): + def make_output( + low_res, + high_res, + s_enhance=None, + t_enhance=None, + model_mom1=None, + output_features_ind=None, + t_enhance_mode='constant', + ): """Make custom batch output Parameters @@ -382,7 +416,7 @@ def make_output(low_res, high_res, """ # Remove first moment from HR and square it out = model_mom1._tf_generate(low_res).numpy() - return (high_res - out)**2 + return (high_res - out) ** 2 class BatchMom2Sep(BatchMom1): @@ -390,10 +424,15 @@ class BatchMom2Sep(BatchMom1): separate from first moment""" @staticmethod - def make_output(low_res, high_res, - s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None, - t_enhance_mode='constant'): + def make_output( + low_res, + high_res, + s_enhance=None, + t_enhance=None, + model_mom1=None, + output_features_ind=None, + t_enhance_mode='constant', + ): """Make custom batch output Parameters @@ -427,12 +466,18 @@ def make_output(low_res, high_res, (batch_size, spatial_1, spatial_2, temporal, features) HR is high-res """ - return super(BatchMom2Sep, - BatchMom2Sep).make_output(low_res, high_res, - s_enhance, t_enhance, - model_mom1, - output_features_ind, - t_enhance_mode)**2 + return ( + super(BatchMom2Sep, BatchMom2Sep).make_output( + low_res, + high_res, + s_enhance, + t_enhance, + model_mom1, + output_features_ind, + t_enhance_mode, + ) + ** 2 + ) class BatchMom2SF(BatchMom1): @@ -440,10 +485,15 @@ class BatchMom2SF(BatchMom1): of subfilter vel""" @staticmethod - def make_output(low_res, high_res, - s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None, - t_enhance_mode='constant'): + def make_output( + low_res, + high_res, + s_enhance=None, + t_enhance=None, + model_mom1=None, + output_features_ind=None, + t_enhance_mode='constant', + ): """Make custom batch output Parameters @@ -480,13 +530,12 @@ def make_output(low_res, high_res, """ # Remove LR and first moment from HR and square it out = model_mom1._tf_generate(low_res).numpy() - enhanced_lr = spatial_simple_enhancing(low_res, - s_enhance=s_enhance) - enhanced_lr = temporal_simple_enhancing(enhanced_lr, - t_enhance=t_enhance, - mode=t_enhance_mode) + enhanced_lr = spatial_simple_enhancing(low_res, s_enhance=s_enhance) + enhanced_lr = temporal_simple_enhancing( + enhanced_lr, t_enhance=t_enhance, mode=t_enhance_mode + ) enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind) - return (high_res - enhanced_lr - out)**2 + return (high_res - enhanced_lr - out) ** 2 class BatchMom2SepSF(BatchMom1SF): @@ -494,10 +543,15 @@ class BatchMom2SepSF(BatchMom1SF): of subfilter vel separate from first moment""" @staticmethod - def make_output(low_res, high_res, - s_enhance=None, 
t_enhance=None, - model_mom1=None, output_features_ind=None, - t_enhance_mode='constant'): + def make_output( + low_res, + high_res, + s_enhance=None, + t_enhance=None, + model_mom1=None, + output_features_ind=None, + t_enhance_mode='constant', + ): """Make custom batch output Parameters @@ -533,12 +587,18 @@ def make_output(low_res, high_res, SF = HR - LR """ # Remove LR from HR and square it - return super(BatchMom2SepSF, - BatchMom2SepSF).make_output(low_res, high_res, - s_enhance, t_enhance, - model_mom1, - output_features_ind, - t_enhance_mode)**2 + return ( + super(BatchMom2SepSF, BatchMom2SepSF).make_output( + low_res, + high_res, + s_enhance, + t_enhance, + model_mom1, + output_features_ind, + t_enhance_mode, + ) + ** 2 + ) class ValidationDataMom1(ValidationData): @@ -547,18 +607,27 @@ class ValidationDataMom1(ValidationData): # Classes to use for handling an individual batch obj. BATCH_CLASS = BatchMom1 - def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, - temporal_coarsening_method='subsample', - temporal_enhancing_method='constant', - output_features_ind=None, - output_features=None, - smoothing=None, smoothing_ignore=None, - model_mom1=None, - s_padding=None, t_padding=None, end_t_padding=False): + def __init__( + self, + data_handlers, + batch_size=8, + s_enhance=3, + t_enhance=1, + temporal_coarsening_method='subsample', + temporal_enhancing_method='constant', + output_features_ind=None, + output_features=None, + smoothing=None, + smoothing_ignore=None, + model_mom1=None, + s_padding=None, + t_padding=None, + end_t_padding=False, + ): """ Parameters ---------- - handlers : list[DataHandler] + data_handlers : list[DataHandler] List of DataHandler instances batch_size : int Size of validation data batches @@ -615,12 +684,11 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, handler_shapes = np.array([d.sample_shape for d in data_handlers]) assert np.all(handler_shapes[0] == handler_shapes) - self.handlers = data_handlers + self.data_handlers = data_handlers self.batch_size = batch_size self.sample_shape = handler_shapes[0] self.val_indices = self._get_val_indices() - self.max = np.ceil( - len(self.val_indices) / (batch_size)) + self.max = np.ceil(len(self.val_indices) / (batch_size)) self.s_enhance = s_enhance self.t_enhance = t_enhance self.s_padding = s_padding @@ -651,7 +719,8 @@ def batch_next(self, high_res): batch : Batch """ return self.BATCH_CLASS.get_coarse_batch( - high_res, self.s_enhance, + high_res, + self.s_enhance, t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, temporal_enhancing_method=self.temporal_enhancing_method, @@ -662,7 +731,8 @@ def batch_next(self, high_res): model_mom1=self.model_mom1, s_padding=self.s_padding, t_padding=self.t_padding, - end_t_padding=self.end_t_padding) + end_t_padding=self.end_t_padding, + ) class BatchHandlerMom1(BatchHandler): @@ -673,14 +743,32 @@ class BatchHandlerMom1(BatchHandler): BATCH_CLASS = BatchMom1 DATA_HANDLER_CLASS = None - def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, - means=None, stds=None, norm=True, n_batches=10, - temporal_coarsening_method='subsample', - temporal_enhancing_method='constant', stdevs_file=None, - means_file=None, overwrite_stats=False, smoothing=None, - smoothing_ignore=None, stats_workers=None, norm_workers=None, - load_workers=None, max_workers=None, model_mom1=None, - s_padding=None, t_padding=None, end_t_padding=False): + def __init__( + self, + data_handlers, + batch_size=8, 
+ s_enhance=3, + t_enhance=1, + means=None, + stds=None, + norm=True, + n_batches=10, + temporal_coarsening_method='subsample', + temporal_enhancing_method='constant', + stdevs_file=None, + means_file=None, + overwrite_stats=False, + smoothing=None, + smoothing_ignore=None, + stats_workers=None, + norm_workers=None, + load_workers=None, + max_workers=None, + model_mom1=None, + s_padding=None, + t_padding=None, + end_t_padding=False, + ): """ Parameters ---------- @@ -769,7 +857,7 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, if max_workers is not None: norm_workers = stats_workers = load_workers = max_workers - msg = ('All data handlers must have the same sample_shape') + msg = 'All data handlers must have the same sample_shape' handler_shapes = np.array([d.sample_shape for d in data_handlers]) assert np.all(handler_shapes[0] == handler_shapes), msg @@ -798,22 +886,27 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, self.overwrite_stats = overwrite_stats self.smoothing = smoothing self.smoothing_ignore = smoothing_ignore or [] - self.smoothed_features = [f for f in self.training_features - if f not in self.smoothing_ignore] + self.smoothed_features = [ + f for f in self.training_features if f not in self.smoothing_ignore + ] self._stats_workers = stats_workers self._norm_workers = norm_workers self._load_workers = load_workers self.model_mom1 = model_mom1 - logger.info(f'Initializing BatchHandler with smoothing={smoothing}. ' - f'Using stats_workers={self.stats_workers}, ' - f'norm_workers={self.norm_workers}, ' - f'load_workers={self.load_workers}.') + logger.info( + f'Initializing BatchHandler with smoothing={smoothing}. ' + f'Using stats_workers={self.stats_workers}, ' + f'norm_workers={self.norm_workers}, ' + f'load_workers={self.load_workers}.' + ) now = dt.now() self.parallel_load() - logger.debug(f'Finished loading data of shape {self.shape} ' - f'for BatchHandler in {dt.now() - now}.') + logger.debug( + f'Finished loading data of shape {self.shape} ' + f'for BatchHandler in {dt.now() - now}.' + ) log_mem(logger, log_level='INFO') if norm: @@ -822,8 +915,10 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, logger.debug('Getting validation data for BatchHandler.') self.val_data = self.VAL_CLASS( - data_handlers, batch_size=batch_size, - s_enhance=s_enhance, t_enhance=t_enhance, + data_handlers, + batch_size=batch_size, + s_enhance=s_enhance, + t_enhance=t_enhance, temporal_coarsening_method=temporal_coarsening_method, temporal_enhancing_method=temporal_enhancing_method, output_features_ind=self.output_features_ind, @@ -833,7 +928,8 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, model_mom1=self.model_mom1, s_padding=self.s_padding, t_padding=self.t_padding, - end_t_padding=self.end_t_padding) + end_t_padding=self.end_t_padding, + ) logger.info('Finished initializing BatchHandler.') log_mem(logger, log_level='INFO') @@ -852,16 +948,25 @@ def __next__(self): handler_index = np.random.randint(0, len(self.data_handlers)) self.current_handler_index = handler_index handler = self.data_handlers[handler_index] - high_res = np.zeros((self.batch_size, self.sample_shape[0], - self.sample_shape[1], self.sample_shape[2], - self.shape[-1]), dtype=np.float32) + high_res = np.zeros( + ( + self.batch_size, + self.sample_shape[0], + self.sample_shape[1], + self.sample_shape[2], + self.shape[-1], + ), + dtype=np.float32, + ) for i in range(self.batch_size): high_res[i, ...] 
= handler.get_next() self.current_batch_indices.append(handler.current_obs_index) batch = self.BATCH_CLASS.get_coarse_batch( - high_res, self.s_enhance, t_enhance=self.t_enhance, + high_res, + self.s_enhance, + t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, temporal_enhancing_method=self.temporal_enhancing_method, output_features_ind=self.output_features_ind, @@ -872,7 +977,8 @@ def __next__(self): model_mom1=self.model_mom1, s_padding=self.s_padding, t_padding=self.t_padding, - end_t_padding=self.end_t_padding) + end_t_padding=self.end_t_padding, + ) self._i += 1 return batch @@ -885,17 +991,23 @@ class SpatialBatchHandlerMom1(BatchHandlerMom1): def __next__(self): if self._i < self.n_batches: - handler_index = np.random.randint( - 0, len(self.data_handlers)) + handler_index = np.random.randint(0, len(self.data_handlers)) handler = self.data_handlers[handler_index] - high_res = np.zeros((self.batch_size, self.sample_shape[0], - self.sample_shape[1], self.shape[-1]), - dtype=np.float32) + high_res = np.zeros( + ( + self.batch_size, + self.sample_shape[0], + self.sample_shape[1], + self.shape[-1], + ), + dtype=np.float32, + ) for i in range(self.batch_size): high_res[i, ...] = handler.get_next()[..., 0, :] batch = self.BATCH_CLASS.get_coarse_batch( - high_res, self.s_enhance, + high_res, + self.s_enhance, output_features_ind=self.output_features_ind, training_features=self.training_features, smoothing=self.smoothing, @@ -903,7 +1015,8 @@ def __next__(self): model_mom1=self.model_mom1, s_padding=self.s_padding, t_padding=self.t_padding, - end_t_padding=self.end_t_padding) + end_t_padding=self.end_t_padding, + ) self._i += 1 return batch @@ -914,35 +1027,41 @@ def __next__(self): class ValidationDataMom1SF(ValidationDataMom1): """Iterator for validation data for first conditional moment of subfilter velocity""" + BATCH_CLASS = BatchMom1SF class ValidationDataMom2(ValidationDataMom1): """Iterator for subfilter validation data for second conditional moment""" + BATCH_CLASS = BatchMom2 class ValidationDataMom2Sep(ValidationDataMom1): """Iterator for subfilter validation data for second conditional moment separate from first moment""" + BATCH_CLASS = BatchMom2Sep class ValidationDataMom2SF(ValidationDataMom1): """Iterator for validation data for second conditional moment of subfilter velocity""" + BATCH_CLASS = BatchMom2SF class ValidationDataMom2SepSF(ValidationDataMom1): """Iterator for validation data for second conditional moment of subfilter velocity separate from first moment""" + BATCH_CLASS = BatchMom2SepSF class BatchHandlerMom1SF(BatchHandlerMom1): """Sup3r batch handling class for first conditional moment of subfilter velocity""" + VAL_CLASS = ValidationDataMom1SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -950,12 +1069,14 @@ class BatchHandlerMom1SF(BatchHandlerMom1): class SpatialBatchHandlerMom1SF(SpatialBatchHandlerMom1): """Sup3r spatial batch handling class for first conditional moment of subfilter velocity""" + VAL_CLASS = ValidationDataMom1SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS class BatchHandlerMom2(BatchHandlerMom1): """Sup3r batch handling class for second conditional moment""" + VAL_CLASS = ValidationDataMom2 BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -963,12 +1084,14 @@ class BatchHandlerMom2(BatchHandlerMom1): class BatchHandlerMom2Sep(BatchHandlerMom1): """Sup3r batch handling class for second conditional moment separate from first moment""" + VAL_CLASS = ValidationDataMom2Sep BATCH_CLASS = VAL_CLASS.BATCH_CLASS class 
SpatialBatchHandlerMom2(SpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment""" + VAL_CLASS = ValidationDataMom2 BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -976,6 +1099,7 @@ class SpatialBatchHandlerMom2(SpatialBatchHandlerMom1): class SpatialBatchHandlerMom2Sep(SpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment separate from first moment""" + VAL_CLASS = ValidationDataMom2Sep BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -983,6 +1107,7 @@ class SpatialBatchHandlerMom2Sep(SpatialBatchHandlerMom1): class BatchHandlerMom2SF(BatchHandlerMom1): """Sup3r batch handling class for second conditional moment of subfilter velocity""" + VAL_CLASS = ValidationDataMom2SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -990,6 +1115,7 @@ class BatchHandlerMom2SF(BatchHandlerMom1): class BatchHandlerMom2SepSF(BatchHandlerMom1): """Sup3r batch handling class for second conditional moment of subfilter velocity separate from first moment""" + VAL_CLASS = ValidationDataMom2SepSF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -997,6 +1123,7 @@ class BatchHandlerMom2SepSF(BatchHandlerMom1): class SpatialBatchHandlerMom2SF(SpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment of subfilter velocity""" + VAL_CLASS = ValidationDataMom2SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -1004,5 +1131,6 @@ class SpatialBatchHandlerMom2SF(SpatialBatchHandlerMom1): class SpatialBatchHandlerMom2SepSF(SpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment of subfilter velocity separate from first moment""" + VAL_CLASS = ValidationDataMom2SepSF BATCH_CLASS = VAL_CLASS.BATCH_CLASS diff --git a/sup3r/preprocessing/data_handling.py b/sup3r/preprocessing/data_handling.py deleted file mode 100644 index 12b1d2aa6..000000000 --- a/sup3r/preprocessing/data_handling.py +++ /dev/null @@ -1,3509 +0,0 @@ -"""Sup3r preprocessing module. 
-@author: bbenton -""" - -import copy -import glob -import logging -import os -import pickle -import warnings -from abc import abstractmethod -from concurrent.futures import ThreadPoolExecutor, as_completed -from datetime import datetime as dt -from fnmatch import fnmatch -from typing import ClassVar - -import numpy as np -import pandas as pd -import xarray as xr -from rex import MultiFileNSRDBX, MultiFileWindX, Resource -from rex.utilities import log_mem -from rex.utilities.fun_utils import get_fun_call_str -from scipy.ndimage.filters import gaussian_filter -from scipy.spatial import KDTree -from scipy.stats import mode - -from sup3r.bias.bias_transforms import get_spatial_bc_factors -from sup3r.preprocessing.feature_handling import ( - BVFreqMon, - BVFreqSquaredH5, - BVFreqSquaredNC, - ClearSkyRatioCC, - ClearSkyRatioH5, - CloudMaskH5, - Feature, - FeatureHandler, - InverseMonNC, - LatLonH5, - LatLonNC, - PotentialTempNC, - PressureNC, - Rews, - Shear, - Tas, - TasMax, - TasMin, - TempNC, - TempNCforCC, - TopoH5, - UWind, - VWind, - WinddirectionNC, - WindspeedNC, -) -from sup3r.utilities import ModuleName -from sup3r.utilities.cli import BaseCLI -from sup3r.utilities.interpolation import Interpolator -from sup3r.utilities.utilities import ( - daily_temporal_coarsening, - estimate_max_workers, - get_chunk_slices, - get_raster_shape, - get_source_type, - get_time_dim_name, - ignore_case_path_fetch, - np_to_pd_times, - spatial_coarsening, - uniform_box_sampler, - uniform_time_sampler, - weighted_box_sampler, - weighted_time_sampler, -) - -np.random.seed(42) - -logger = logging.getLogger(__name__) - - -class InputMixIn: - """MixIn class with properties and methods for handling the spatiotemporal - data domain to extract from source data.""" - - def __init__( - self, - target, - shape, - raster_file=None, - raster_index=None, - temporal_slice=slice(None, None, 1), - ): - """Provide properties of the spatiotemporal data domain - - Parameters - ---------- - target : tuple - (lat, lon) lower left corner of raster. Either need target+shape or - raster_file. - shape : tuple - (rows, cols) grid size. Either need target+shape or raster_file. - raster_file : str | None - File for raster_index array for the corresponding target and shape. - If specified the raster_index will be loaded from the file if it - exists or written to the file if it does not yet exist. If None and - raster_index is not provided raster_index will be calculated - directly. Either need target+shape, raster_file, or raster_index - input. - raster_index : list - List of tuples or slices. Used as an alternative to computing the - raster index from target+shape or loading the raster index from - file - temporal_slice : slice - Slice specifying extent and step of temporal extraction. e.g. - slice(start, stop, time_pruning). If equal to slice(None, None, 1) - the full time dimension is selected. 
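As a quick aside on the temporal_slice argument documented above, the same semantics can be reproduced with a plain pandas DatetimeIndex; the hourly index below is made up for illustration and is not read from any sup3r input.

import pandas as pd

# 48 hourly steps standing in for a raw input time index
raw_time_index = pd.date_range('2015-01-01', periods=48, freq='1H')

full = raw_time_index[slice(None, None, 1)]   # full time dimension selected
pruned = raw_time_index[slice(0, 24, 3)]      # first day only, every 3rd step

print(len(full), len(pruned))  # 48 8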
- """ - self.raster_file = raster_file - self.target = target - self.grid_shape = shape - self.raster_index = raster_index - self.temporal_slice = temporal_slice - self.lat_lon = None - self.overwrite_ti_cache = False - self.max_workers = None - self._ti_workers = None - self._raw_time_index = None - self._raw_tsteps = None - self._time_index = None - self._time_index_file = None - self._file_paths = None - self._cache_pattern = None - self._invert_lat = None - self._raw_lat_lon = None - self._full_raw_lat_lon = None - self._single_ts_files = None - self._worker_attrs = ['ti_workers'] - self.res_kwargs = {} - - @property - def raw_tsteps(self): - """Get number of time steps for all input files""" - if self._raw_tsteps is None: - if self.single_ts_files: - self._raw_tsteps = len(self.file_paths) - else: - self._raw_tsteps = len(self.raw_time_index) - return self._raw_tsteps - - @property - def single_ts_files(self): - """Check if there is a file for each time step, in which case we can - send a subset of files to the data handler according to ti_pad_slice""" - if self._single_ts_files is None: - logger.debug('Checking if input files are single timestep.') - t_steps = self.get_time_index(self.file_paths[:1], max_workers=1) - check = ( - len(self._file_paths) == len(self.raw_time_index) - and t_steps is not None - and len(t_steps) == 1 - ) - self._single_ts_files = check - return self._single_ts_files - - @staticmethod - def get_capped_workers(max_workers_cap, max_workers): - """Get max number of workers for a given job. Capped to global max - workers if specified - - Parameters - ---------- - max_workers_cap : int | None - Cap for job specific max_workers - max_workers : int | None - Job specific max_workers - - Returns - ------- - max_workers : int | None - job specific max_workers capped by max_workers_cap if provided - """ - if max_workers is None and max_workers_cap is None: - return max_workers - elif max_workers_cap is not None and max_workers is None: - return max_workers_cap - elif max_workers is not None and max_workers_cap is None: - return max_workers - else: - return np.min((max_workers_cap, max_workers)) - - def cap_worker_args(self, max_workers): - """Cap all workers args by max_workers""" - for v in self._worker_attrs: - capped_val = self.get_capped_workers(getattr(self, v), max_workers) - setattr(self, v, capped_val) - - @classmethod - @abstractmethod - def get_full_domain(cls, file_paths): - """Get full lat/lon grid for when target + shape are not specified""" - - @classmethod - @abstractmethod - def get_lat_lon(cls, file_paths, raster_index, invert_lat=False): - """Get lat/lon grid for requested target and shape""" - - @abstractmethod - def get_time_index(self, file_paths, max_workers=None, **kwargs): - """Get raw time index for source data""" - - @property - def input_file_info(self): - """Method to provide info about files in log output. Since NETCDF files - have single time slices printing out all the file paths is just a text - dump without much info. - - Returns - ------- - str - message to append to log output that does not include a huge info - dump of file paths - """ - msg = ( - f'source files with dates from {self.raw_time_index[0]} to ' - f'{self.raw_time_index[-1]}' - ) - return msg - - @property - def temporal_slice(self): - """Get temporal range to extract from full dataset""" - return self._temporal_slice - - @temporal_slice.setter - def temporal_slice(self, temporal_slice): - """Make sure temporal_slice is a slice. 
Need to do this because json - cannot save slices so we can instead save as list and then convert. - - Parameters - ---------- - temporal_slice : tuple | list | slice - Time range to extract from input data. If a list or tuple it will - be concerted to a slice. Tuple or list must have at least two - elements and no more than three, corresponding to the inputs of - slice() - """ - msg = 'temporal_slice must be tuple, list, or slice' - assert isinstance(temporal_slice, (tuple, list, slice)), msg - if isinstance(temporal_slice, slice): - self._temporal_slice = temporal_slice - else: - check = len(temporal_slice) <= 3 - msg = ( - 'If providing list or tuple for temporal_slice length must ' - 'be <= 3' - ) - assert check, msg - self._temporal_slice = slice(*temporal_slice) - if self._temporal_slice.step is None: - self._temporal_slice = slice( - self._temporal_slice.start, self._temporal_slice.stop, 1 - ) - if self._temporal_slice.start is None: - self._temporal_slice = slice( - 0, self._temporal_slice.stop, self._temporal_slice.step - ) - - @property - def file_paths(self): - """Get file paths for input data""" - return self._file_paths - - @file_paths.setter - def file_paths(self, file_paths): - """Set file paths attr and do initial glob / sort - - Parameters - ---------- - file_paths : str | list - A list of files to extract raster data from. Each file must have - the same number of timesteps. Can also pass a string with a - unix-style file path which will be passed through glob.glob - """ - self._file_paths = file_paths - if isinstance(self._file_paths, str): - if '*' in file_paths: - self._file_paths = glob.glob(self._file_paths) - else: - self._file_paths = [self._file_paths] - - msg = ( - 'No valid files provided to DataHandler. ' - f'Received file_paths={file_paths}. Aborting.' - ) - assert file_paths is not None and len(self._file_paths) > 0, msg - - self._file_paths = sorted(self._file_paths) - - @property - def ti_workers(self): - """Get max number of workers for computing time index""" - if self._ti_workers is None: - self._ti_workers = len(self._file_paths) - return self._ti_workers - - @ti_workers.setter - def ti_workers(self, val): - """Set max number of workers for computing time index""" - self._ti_workers = val - - @property - def cache_pattern(self): - """Get correct cache file pattern for formatting. - - Returns - ------- - _cache_pattern : str - The cache file pattern with formatting keys included. - """ - if self._cache_pattern is not None: - if '.pkl' not in self._cache_pattern: - self._cache_pattern += '.pkl' - if '{feature}' not in self._cache_pattern: - self._cache_pattern = self._cache_pattern.replace( - '.pkl', '_{feature}.pkl' - ) - basedir = os.path.dirname(self._cache_pattern) - if not os.path.exists(basedir): - os.makedirs(basedir, exist_ok=True) - return self._cache_pattern - - @cache_pattern.setter - def cache_pattern(self, cache_pattern): - """Update the cache file pattern""" - self._cache_pattern = cache_pattern - - @property - def need_full_domain(self): - """Check whether we need to get the full lat/lon grid to determine - target and shape values""" - no_raster_file = self.raster_file is None or not os.path.exists( - self.raster_file - ) - no_target_shape = self._target is None or self._grid_shape is None - need_full = no_raster_file and no_target_shape - - if need_full: - logger.info( - 'Target + shape not specified. Getting full domain ' - f'for {self.file_paths[0]}.' 
- ) - - return need_full - - @property - def full_raw_lat_lon(self): - """Get the full lat/lon grid without doing any latitude inversion""" - if self._full_raw_lat_lon is None and self.need_full_domain: - self._full_raw_lat_lon = self.get_full_domain(self.file_paths[:1]) - return self._full_raw_lat_lon - - @property - def raw_lat_lon(self): - """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon - array with same ordering in last dimension. This returns the gid - without any lat inversion. - - Returns - ------- - ndarray - """ - raster_file_exists = self.raster_file is not None and os.path.exists( - self.raster_file - ) - - if self.full_raw_lat_lon is not None and raster_file_exists: - self._raw_lat_lon = self.full_raw_lat_lon[self.raster_index] - - elif self.full_raw_lat_lon is not None and not raster_file_exists: - self._raw_lat_lon = self.full_raw_lat_lon - - if self._raw_lat_lon is None: - self._raw_lat_lon = self.get_lat_lon( - self.file_paths[0:1], self.raster_index, invert_lat=False - ) - return self._raw_lat_lon - - @property - def lat_lon(self): - """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon - array with same ordering in last dimension. This ensures that the - lower left hand corner of the domain is given by lat_lon[-1, 0] - - Returns - ------- - ndarray - """ - if self._lat_lon is None: - self._lat_lon = self.raw_lat_lon - if self.invert_lat: - self._lat_lon = self._lat_lon[::-1] - return self._lat_lon - - @lat_lon.setter - def lat_lon(self, lat_lon): - """Update lat lon""" - self._lat_lon = lat_lon - - @property - def latitude(self): - """Return latitude array""" - return self.lat_lon[..., 0] - - @property - def longitude(self): - """Return longitude array""" - return self.lat_lon[..., 1] - - @property - def invert_lat(self): - """Whether to invert the latitude axis during data extraction. This is - to enforce a descending latitude ordering so that the lower left corner - of the grid is at idx=(-1, 0) instead of idx=(0, 0)""" - if self._invert_lat is None: - lat_lon = self.raw_lat_lon - self._invert_lat = not self.lats_are_descending(lat_lon) - return self._invert_lat - - @property - def target(self): - """Get lower left corner of raster - - Returns - ------- - _target: tuple - (lat, lon) lower left corner of raster. - """ - if self._target is None: - lat_lon = self.lat_lon - if not self.lats_are_descending(lat_lon): - self._target = tuple(lat_lon[0, 0, :]) - else: - self._target = tuple(lat_lon[-1, 0, :]) - return self._target - - @target.setter - def target(self, target): - """Update target property""" - self._target = target - - @classmethod - def lats_are_descending(cls, lat_lon): - """Check if latitudes are in descending order (i.e. the target - coordinate is already at the bottom left corner) - - Parameters - ---------- - lat_lon : np.ndarray - Lat/Lon array with shape (n_lats, n_lons, 2) - - Returns - ------- - bool - """ - return lat_lon[-1, 0, 0] < lat_lon[0, 0, 0] - - @property - def grid_shape(self): - """Get shape of raster - - Returns - ------- - _grid_shape: tuple - (rows, cols) grid size. - """ - if self._grid_shape is None: - self._grid_shape = self.lat_lon.shape[:-1] - return self._grid_shape - - @grid_shape.setter - def grid_shape(self, grid_shape): - """Update grid_shape property""" - self._grid_shape = grid_shape - - @property - def source_type(self): - """Get data type for source files. 
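To make the lats_are_descending/invert_lat convention above concrete, here is a small self-contained illustration with a made-up grid (not part of the handler); the point is simply that after the flip the lower left corner of the domain sits at index (-1, 0).

import numpy as np

lats = np.linspace(35, 40, 5)                # ascending latitudes (made up)
lons = np.linspace(-105, -100, 4)
lat_grid, lon_grid = np.meshgrid(lats, lons, indexing='ij')
lat_lon = np.dstack([lat_grid, lon_grid])    # (n_lats, n_lons, 2)

descending = lat_lon[-1, 0, 0] < lat_lon[0, 0, 0]
if not descending:
    lat_lon = lat_lon[::-1]                  # same flip as invert_lat above

print(lat_lon[-1, 0])                        # lower left corner: [  35. -105.]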
Either nc or h5""" - return get_source_type(self.file_paths) - - @property - def time_index_file(self): - """Get time index file path""" - if self.source_type == 'h5': - return None - - if self.cache_pattern is not None and self._time_index_file is None: - basename = self.cache_pattern.replace('{times}', '') - basename = basename.replace('{shape}', str(len(self.file_paths))) - basename = basename.replace('_{target}', '') - basename = basename.replace('{feature}', 'time_index') - tmp = basename.split('_') - if tmp[-2].isdigit() and tmp[-1].strip('.pkl').isdigit(): - basename = '_'.join(tmp[:-1]) + '.pkl' - self._time_index_file = basename - return self._time_index_file - - @property - def raw_time_index(self): - """Time index for input data without time pruning. This is the base - time index for the raw input data.""" - - if self._raw_time_index is None: - check = ( - self.time_index_file is not None - and os.path.exists(self.time_index_file) - and not self.overwrite_ti_cache - ) - if check: - logger.debug( - 'Loading raw_time_index from ' f'{self.time_index_file}' - ) - with open(self.time_index_file, 'rb') as f: - self._raw_time_index = pd.DatetimeIndex(pickle.load(f)) - else: - self._raw_time_index = self._build_and_cache_time_index() - - check = ( - self._raw_time_index is not None - and (self._raw_time_index.hour == 12).all() - ) - if check: - self._raw_time_index -= pd.Timedelta(12, 'h') - elif self._raw_time_index is None: - self._raw_time_index = [None, None] - - if self._single_ts_files: - self.time_index_conflict_check() - return self._raw_time_index - - def time_index_conflict_check(self): - """Check if the number of input files and the length of the time index - is the same""" - msg = ( - f'Number of time steps ({len(self._raw_time_index)}) and files ' - f'({self.raw_tsteps}) conflict!' - ) - check = len(self._raw_time_index) == self.raw_tsteps - assert check, msg - - def _build_and_cache_time_index(self): - """Build time index and cache if time_index_file is not None""" - now = dt.now() - logger.debug( - f'Getting time index for {len(self.file_paths)} ' - f'input files. Using ti_workers={self.ti_workers}' - f' and res_kwargs={self.res_kwargs}' - ) - self._raw_time_index = self.get_time_index( - self.file_paths, max_workers=self.ti_workers, **self.res_kwargs - ) - - if self.time_index_file is not None: - logger.debug(f'Saved raw_time_index to {self.time_index_file}') - with open(self.time_index_file, 'wb') as f: - pickle.dump(self._raw_time_index, f) - logger.debug(f'Built full time index in {dt.now() - now} seconds.') - return self._raw_time_index - - @property - def time_index(self): - """Time index for input data with time pruning. 
This is the raw time - index with a cropped range and time step applied.""" - if self._time_index is None: - self._time_index = self.raw_time_index[self.temporal_slice] - return self._time_index - - @time_index.setter - def time_index(self, time_index): - """Update time index""" - self._time_index = time_index - - @property - def time_freq_hours(self): - """Get the time frequency in hours as a float""" - ti_deltas = self.raw_time_index - np.roll(self.raw_time_index, 1) - ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600 - time_freq = float(mode(ti_deltas_hours).mode) - return time_freq - - @property - def timestamp_0(self): - """Get a string timestamp for the first time index value with the - format YYYYMMDDHHMMSS""" - - time_stamp = self.time_index[0] - yyyy = str(time_stamp.year) - mm = str(time_stamp.month).zfill(2) - dd = str(time_stamp.day).zfill(2) - hh = str(time_stamp.hour).zfill(2) - min = str(time_stamp.minute).zfill(2) - ss = str(time_stamp.second).zfill(2) - ts0 = yyyy + mm + dd + hh + min + ss - return ts0 - - @property - def timestamp_1(self): - """Get a string timestamp for the last time index value with the - format YYYYMMDDHHMMSS""" - - time_stamp = self.time_index[-1] - yyyy = str(time_stamp.year) - mm = str(time_stamp.month).zfill(2) - dd = str(time_stamp.day).zfill(2) - hh = str(time_stamp.hour).zfill(2) - min = str(time_stamp.minute).zfill(2) - ss = str(time_stamp.second).zfill(2) - ts1 = yyyy + mm + dd + hh + min + ss - return ts1 - - -class DataHandler(FeatureHandler, InputMixIn): - """Sup3r data handling and extraction for low-res source data or for - artificially coarsened high-res source data for training. - - The sup3r data handler class is based on a 4D numpy array of shape: - (spatial_1, spatial_2, temporal, features) - """ - - # list of features / feature name patterns that are input to the generative - # model but are not part of the synthetic output and are not sent to the - # discriminator. These are case-insensitive and follow the Unix shell-style - # wildcard format. - TRAIN_ONLY_FEATURES = ( - 'BVF*', - 'inversemoninobukhovlength_*', - 'RMOL', - 'topography', - ) - - def __init__( - self, - file_paths, - features, - target=None, - shape=None, - max_delta=20, - temporal_slice=slice(None, None, 1), - hr_spatial_coarsen=None, - time_roll=0, - val_split=0.0, - sample_shape=(10, 10, 1), - raster_file=None, - raster_index=None, - shuffle_time=False, - time_chunk_size=None, - cache_pattern=None, - overwrite_cache=False, - overwrite_ti_cache=False, - load_cached=False, - train_only_features=None, - handle_features=None, - single_ts_files=None, - mask_nan=False, - worker_kwargs=None, - res_kwargs=None, - ): - """ - Parameters - ---------- - file_paths : str | list - A single source h5 wind file to extract raster data from or a list - of netcdf files with identical grid. The string can be a unix-style - file path which will be passed through glob.glob - features : list - list of features to extract from the provided data - target : tuple - (lat, lon) lower left corner of raster. Either need target+shape or - raster_file. - shape : tuple - (rows, cols) grid size. Either need target+shape or raster_file. - max_delta : int, optional - Optional maximum limit on the raster shape that is retrieved at - once. If shape is (20, 20) and max_delta=10, the full raster will - be retrieved in four chunks of (10, 10). 
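Purely as an illustration of the max_delta example above (a simplification, not the handler's actual raster-index retrieval), a (20, 20) request with max_delta=10 can be split into four (10, 10) chunks like so:

# illustrative chunking only; shape and max_delta follow the docstring example
shape, max_delta = (20, 20), 10
chunks = [(slice(r, r + max_delta), slice(c, c + max_delta))
          for r in range(0, shape[0], max_delta)
          for c in range(0, shape[1], max_delta)]
print(len(chunks))  # 4 chunks of (10, 10)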
This helps adapt to - non-regular grids that curve over large distances, by default 20 - temporal_slice : slice - Slice specifying extent and step of temporal extraction. e.g. - slice(start, stop, time_pruning). If equal to slice(None, None, 1) - the full time dimension is selected. - hr_spatial_coarsen : int | None - Optional input to coarsen the high-resolution spatial field. This - can be used if (for example) you have 2km source data, but you want - the final high res prediction target to be 4km resolution, then - hr_spatial_coarsen would be 2 so that the GAN is trained on - aggregated 4km high-res data. - time_roll : int - The number of places by which elements are shifted in the time - axis. Can be used to convert data to different timezones. This is - passed to np.roll(a, time_roll, axis=2) and happens AFTER the - temporal_slice operation. - val_split : float32 - Fraction of data to store for validation - sample_shape : tuple - Size of spatial and temporal domain used in a single high-res - observation for batching - raster_file : str | None - File for raster_index array for the corresponding target and shape. - If specified the raster_index will be loaded from the file if it - exists or written to the file if it does not yet exist. If None and - raster_index is not provided raster_index will be calculated - directly. Either need target+shape, raster_file, or raster_index - input. - raster_index : list - List of tuples or slices. Used as an alternative to computing the - raster index from target+shape or loading the raster index from - file - shuffle_time : bool - Whether to shuffle time indices before validation split - time_chunk_size : int - Size of chunks to split time dimension into for parallel data - extraction. If running in serial this can be set to the size of the - full time index for best performance. - cache_pattern : str | None - Pattern for files for saving feature data. e.g. - file_path_{feature}.pkl. Each feature will be saved to a file with - the feature name replaced in cache_pattern. If not None - feature arrays will be saved here and not stored in self.data until - load_cached_data is called. The cache_pattern can also include - {shape}, {target}, {times} which will help ensure unique cache - files for complex problems. - overwrite_cache : bool - Whether to overwrite any previously saved cache files. - overwrite_ti_cache : bool - Whether to overwrite any previously saved time index cache files. - overwrite_ti_cache : bool - Whether to overwrite saved time index cache files. - load_cached : bool - Whether to load data from cache files - train_only_features : list | tuple | None - List of feature names or patt*erns that should only be included in - the training set and not the output. If None (default), this will - default to the class TRAIN_ONLY_FEATURES attribute. - handle_features : list | None - Optional list of features which are available in the provided data. - Providing this eliminates the need for an initial search of - available features prior to data extraction. - single_ts_files : bool | None - Whether input files are single time steps or not. If they are this - enables some reduced computation. If None then this will be - determined from file_paths directly. - mask_nan : bool - Flag to mask out (remove) any timesteps with NaN data from the - source dataset. This is False by default because it can create - discontinuities in the timeseries. - worker_kwargs : dict | None - Dictionary of worker values. 
Can include max_workers, - extract_workers, compute_workers, load_workers, norm_workers, - and ti_workers. Each argument needs to be an integer or None. - - The value of `max workers` will set the value of all other worker - args. If max_workers == 1 then all processes will be serialized. If - max_workers == None then other worker args will use their own - provided values. - - `extract_workers` is the max number of workers to use for - extracting features from source data. If None it will be estimated - based on memory limits. If 1 processes will be serialized. - `compute_workers` is the max number of workers to use for computing - derived features from raw features in source data. `load_workers` - is the max number of workers to use for loading cached feature - data. `norm_workers` is the max number of workers to use for - normalizing feature data. `ti_workers` is the max number of - workers to use to get full time index. Useful when there are many - input files each with a single time step. If this is greater than - one, time indices for input files will be extracted in parallel - and then concatenated to get the full time index. If input files - do not all have time indices or if there are few input files this - should be set to one. - res_kwargs : dict | None - kwargs passed to source handler for data extraction. e.g. This - could be {'parallel': True, - 'concat_dim': 'Time', - 'combine': 'nested', - 'chunks': {'south_north': 120, 'west_east': 120}} - which then gets passed to xr.open_mfdataset(file, **res_kwargs) - """ - InputMixIn.__init__( - self, - target=target, - shape=shape, - raster_file=raster_file, - raster_index=raster_index, - temporal_slice=temporal_slice, - ) - - self.file_paths = file_paths - self.features = ( - features if isinstance(features, (list, tuple)) else [features] - ) - self.val_time_index = None - self.max_delta = max_delta - self.val_split = val_split - self.sample_shape = sample_shape - self.hr_spatial_coarsen = hr_spatial_coarsen or 1 - self.time_roll = time_roll - self.shuffle_time = shuffle_time - self.current_obs_index = None - self.overwrite_cache = overwrite_cache - self.overwrite_ti_cache = overwrite_ti_cache - self.load_cached = load_cached - self.data = None - self.val_data = None - self.res_kwargs = res_kwargs or {} - self._single_ts_files = single_ts_files - self._cache_pattern = cache_pattern - self._train_only_features = train_only_features - self._time_chunk_size = time_chunk_size - self._handle_features = handle_features - self._cache_files = None - self._extract_features = None - self._noncached_features = None - self._raw_features = None - self._raw_data = {} - self._time_chunks = None - worker_kwargs = worker_kwargs or {} - self.max_workers = worker_kwargs.get('max_workers', None) - self._ti_workers = worker_kwargs.get('ti_workers', None) - self._extract_workers = worker_kwargs.get('extract_workers', None) - self._norm_workers = worker_kwargs.get('norm_workers', None) - self._load_workers = worker_kwargs.get('load_workers', None) - self._compute_workers = worker_kwargs.get('compute_workers', None) - self._worker_attrs = [ - '_ti_workers', - '_norm_workers', - '_compute_workers', - '_extract_workers', - '_load_workers', - ] - - self.preflight() - - try_load = ( - cache_pattern is not None - and not self.overwrite_cache - and all(os.path.exists(fp) for fp in self.cache_files) - ) - - overwrite = ( - self.overwrite_cache - and self.cache_files is not None - and all(os.path.exists(fp) for fp in self.cache_files) - ) - - if try_load and 
self.load_cached: - logger.info( - f'All {self.cache_files} exist. Loading from cache ' - f'instead of extracting from source files.' - ) - self.load_cached_data() - - elif try_load and not self.load_cached: - self.clear_data() - logger.info( - f'All {self.cache_files} exist. Call ' - 'load_cached_data() or use load_cache=True to load ' - 'this data from cache files.' - ) - else: - if overwrite: - logger.info( - f'{self.cache_files} exists but overwrite_cache ' - 'is set to True. Proceeding with extraction.' - ) - - self._raster_size_check() - self._run_data_init_if_needed() - - if cache_pattern is not None: - self.cache_data(self.cache_files) - self.data = None if not self.load_cached else self.data - - self._val_split_check() - - if mask_nan: - nan_mask = np.isnan(self.data).any(axis=(0, 1, 3)) - logger.info( - 'Removing {} out of {} timesteps due to NaNs'.format( - nan_mask.sum(), self.data.shape[2] - ) - ) - self.data = self.data[:, :, ~nan_mask, :] - - logger.info('Finished intializing DataHandler.') - log_mem(logger, log_level='INFO') - - def _run_data_init_if_needed(self): - """Check if any features need to be extracted and proceed with data - extraction""" - if any(self.features): - self.data = self.run_all_data_init() - nan_perc = 100 * np.isnan(self.data).sum() / self.data.size - if nan_perc > 0: - msg = 'Data has {:.2f}% NaN values!'.format(nan_perc) - logger.warning(msg) - warnings.warn(msg) - - def _raster_size_check(self): - """Check if the sample_shape is larger than the requested raster - size""" - bad_shape = ( - self.sample_shape[0] > self.grid_shape[0] - and self.sample_shape[1] > self.grid_shape[1] - ) - if bad_shape: - msg = ( - f'spatial_sample_shape {self.sample_shape[:2]} is ' - f'larger than the raster size {self.grid_shape}' - ) - logger.warning(msg) - warnings.warn(msg) - - def _val_split_check(self): - """Check if val_split > 0 and split data into validation and training. - Make sure validation data is larger than sample_shape""" - - if self.data is not None and self.val_split > 0.0: - self.data, self.val_data = self.split_data() - msg = ( - f'Validation data has shape={self.val_data.shape} ' - f'and sample_shape={self.sample_shape}. Use a smaller ' - 'sample_shape and/or larger val_split.' - ) - check = any( - val_size < samp_size - for val_size, samp_size in zip( - self.val_data.shape, self.sample_shape - ) - ) - if check: - logger.warning(msg) - warnings.warn(msg) - - @classmethod - @abstractmethod - def get_full_domain(cls, file_paths): - """Get target and shape for full domain""" - - def clear_data(self): - """Free memory used for data arrays""" - self.data = None - self.val_data = None - - @classmethod - @abstractmethod - def source_handler(cls, file_paths, **kwargs): - """Handle for source data. Uses xarray, ResourceX, etc. - - NOTE: that xarray appears to treat open file handlers as singletons - within a threadpool, so its okay to open this source_handler without a - context handler or a .close() statement. 
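The mask_nan branch shown above drops whole time steps that contain NaNs; a standalone sketch of that exact operation on a made-up (spatial_1, spatial_2, temporal, features) array:

import numpy as np

data = np.random.rand(4, 4, 10, 2).astype(np.float32)
data[0, 0, 3, 1] = np.nan                        # corrupt one time step

nan_mask = np.isnan(data).any(axis=(0, 1, 3))    # one flag per time step
print('Removing {} out of {} timesteps due to NaNs'.format(
    nan_mask.sum(), data.shape[2]))
data = data[:, :, ~nan_mask, :]
print(data.shape)                                # (4, 4, 9, 2)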
- """ - - @property - def attrs(self): - """Get atttributes of input data - - Returns - ------- - dict - Dictionary of attributes - """ - handle = self.source_handler(self.file_paths) - desc = handle.attrs - return desc - - @property - def train_only_features(self): - """Features to use for training only and not output""" - if self._train_only_features is None: - self._train_only_features = self.TRAIN_ONLY_FEATURES - return self._train_only_features - - @property - def extract_workers(self): - """Get upper bound for extract workers based on memory limits. Used to - extract data from source dataset. The max number of extract workers - is number of time chunks * number of features""" - proc_mem = 4 * self.grid_mem * len(self.time_index) - proc_mem /= len(self.time_chunks) - n_procs = len(self.time_chunks) * len(self.extract_features) - n_procs = int(np.ceil(n_procs)) - extract_workers = estimate_max_workers( - self._extract_workers, proc_mem, n_procs - ) - return extract_workers - - @property - def compute_workers(self): - """Get upper bound for compute workers based on memory limits. Used to - compute derived features from source dataset.""" - proc_mem = int( - np.ceil( - len(self.extract_features) - / np.maximum(len(self.derive_features), 1) - ) - ) - proc_mem *= 4 * self.grid_mem * len(self.time_index) - proc_mem /= len(self.time_chunks) - n_procs = len(self.time_chunks) * len(self.derive_features) - n_procs = int(np.ceil(n_procs)) - compute_workers = estimate_max_workers( - self._compute_workers, proc_mem, n_procs - ) - return compute_workers - - @property - def load_workers(self): - """Get upper bound on load workers based on memory limits. Used to load - cached data.""" - proc_mem = 2 * self.feature_mem - n_procs = 1 - if self.cache_files is not None: - n_procs = len(self.cache_files) - load_workers = estimate_max_workers( - self._load_workers, proc_mem, n_procs - ) - return load_workers - - @property - def norm_workers(self): - """Get upper bound on workers used for normalization.""" - if self.data is not None: - norm_workers = estimate_max_workers( - self._norm_workers, 2 * self.feature_mem, self.shape[-1] - ) - else: - norm_workers = self._norm_workers - return norm_workers - - @property - def time_chunks(self): - """Get time chunks which will be extracted from source data - - Returns - ------- - _time_chunks : list - List of time chunks used to split up source data time dimension - so that each chunk can be extracted individually - """ - if self._time_chunks is None: - if self.is_time_independent: - self._time_chunks = [slice(None)] - else: - self._time_chunks = get_chunk_slices( - len(self.raw_time_index), - self.time_chunk_size, - self.temporal_slice, - ) - return self._time_chunks - - @property - def is_time_independent(self): - """Get whether source data files are time independent""" - return self.raw_time_index[0] is None - - @property - def n_tsteps(self): - """Get number of time steps to extract""" - if self.is_time_independent: - return 1 - else: - return len(self.raw_time_index[self.temporal_slice]) - - @property - def time_chunk_size(self): - """Get upper bound on time chunk size based on memory limits""" - if self._time_chunk_size is None: - step_mem = self.feature_mem * len(self.extract_features) - step_mem /= len(self.time_index) - if step_mem == 0: - self._time_chunk_size = self.n_tsteps - else: - self._time_chunk_size = np.min( - [int(1e9 / step_mem), self.n_tsteps] - ) - logger.info( - 'time_chunk_size arg not specified. Using ' - f'{self._time_chunk_size}.' 
- ) - return self._time_chunk_size - - @property - def cache_files(self): - """Cache files for storing extracted data""" - if self._cache_files is None: - self._cache_files = self.get_cache_file_names(self.cache_pattern) - return self._cache_files - - @property - def raster_index(self): - """Raster index property""" - if self._raster_index is None: - self._raster_index = self.get_raster_index() - return self._raster_index - - @raster_index.setter - def raster_index(self, raster_index): - """Update raster index property""" - self._raster_index = raster_index - - @classmethod - def get_handle_features(cls, file_paths): - """Get all available features in input data - - Parameters - ---------- - file_paths : list - List of input file paths - - Returns - ------- - handle_features : list - List of available input features - """ - handle_features = [] - for f in file_paths: - handle = cls.source_handler([f]) - for r in handle: - handle_features.append(Feature.get_basename(r)) - return list(set(handle_features)) - - @property - def handle_features(self): - """All features available in raw input""" - if self._handle_features is None: - self._handle_features = self.get_handle_features(self.file_paths) - return self._handle_features - - @property - def noncached_features(self): - """Get list of features needing extraction or derivation""" - if self._noncached_features is None: - self._noncached_features = self.check_cached_features( - self.features, - cache_files=self.cache_files, - overwrite_cache=self.overwrite_cache, - load_cached=self.load_cached, - ) - return self._noncached_features - - @property - def extract_features(self): - """Features to extract directly from the source handler""" - lower_features = [f.lower() for f in self.handle_features] - return [ - f - for f in self.raw_features - if self.lookup(f, 'compute') is None - or Feature.get_basename(f.lower()) in lower_features - ] - - @property - def derive_features(self): - """List of features which need to be derived from other features""" - derive_features = [ - f - for f in set( - list(self.noncached_features) + list(self.extract_features) - ) - if f not in self.extract_features - ] - return derive_features - - @property - def cached_features(self): - """List of features which have been requested but have been determined - not to need extraction. Thus they have been cached already.""" - return [f for f in self.features if f not in self.noncached_features] - - @property - def raw_features(self): - """Get list of features needed for computations""" - if self._raw_features is None: - self._raw_features = self.get_raw_feature_list( - self.noncached_features, self.handle_features - ) - return self._raw_features - - @property - def output_features(self): - """Get a list of features that should be output by the generative model - corresponding to the features in the high res batch array.""" - out = [] - for feature in self.features: - ignore = any( - fnmatch(feature.lower(), pattern.lower()) - for pattern in self.train_only_features - ) - if not ignore: - out.append(feature) - return out - - @property - def grid_mem(self): - """Get memory used by a feature at a single time step - - Returns - ------- - int - Number of bytes for a single feature array at a single time step - """ - grid_mem = np.product(self.grid_shape) - # assuming feature arrays are float32 (4 bytes) - return 4 * grid_mem - - @property - def feature_mem(self): - """Number of bytes for a single feature array. Used to estimate - max_workers. 
- - Returns - ------- - int - Number of bytes for a single feature array - """ - feature_mem = self.grid_mem * len(self.time_index) - return feature_mem - - def preflight(self): - """Run some preflight checks and verify that the inputs are valid""" - - self.cap_worker_args(self.max_workers) - - if len(self.sample_shape) == 2: - logger.info( - 'Found 2D sample shape of {}. Adding temporal dim of 1'.format( - self.sample_shape - ) - ) - self.sample_shape = (*self.sample_shape, 1) - - start = self.temporal_slice.start - stop = self.temporal_slice.stop - n_steps = self.n_tsteps - msg = ( - f'Temporal slice step ({self.temporal_slice.step}) does not ' - f'evenly divide the number of time steps ({n_steps})' - ) - check = self.temporal_slice.step is None - check = check or n_steps % self.temporal_slice.step == 0 - if not check: - logger.warning(msg) - warnings.warn(msg) - - msg = ( - f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger ' - 'than the number of time steps in the raw data ' - f'({len(self.raw_time_index)}).' - ) - if len(self.raw_time_index) < self.sample_shape[2]: - logger.warning(msg) - warnings.warn(msg) - - msg = ( - f'The requested time slice {self.temporal_slice} conflicts ' - f'with the number of time steps ({len(self.raw_time_index)}) ' - 'in the raw data' - ) - t_slice_is_subset = start is not None and stop is not None - good_subset = ( - t_slice_is_subset - and (stop - start <= len(self.raw_time_index)) - and stop <= len(self.raw_time_index) - and start <= len(self.raw_time_index) - ) - if t_slice_is_subset and not good_subset: - logger.error(msg) - raise RuntimeError(msg) - - msg = ( - f'Initializing DataHandler {self.input_file_info}. ' - f'Getting temporal range {self.time_index[0]!s} to ' - f'{self.time_index[-1]!s} (inclusive) ' - f'based on temporal_slice {self.temporal_slice}' - ) - logger.info(msg) - - logger.info( - f'Using max_workers={self.max_workers}, ' - f'norm_workers={self.norm_workers}, ' - f'extract_workers={self.extract_workers}, ' - f'compute_workers={self.compute_workers}, ' - f'load_workers={self.load_workers}, ' - f'ti_workers={self.ti_workers}' - ) - - @classmethod - def get_lat_lon(cls, file_paths, raster_index, invert_lat=False): - """Get lat/lon grid for requested target and shape - - Parameters - ---------- - file_paths : list - path to data file - raster_index : ndarray | list - Raster index array or list of slices - invert_lat : bool - Flag to invert data along the latitude axis. Wrf data tends to use - an increasing ordering for latitude while wtk uses a decreasing - ordering. - - Returns - ------- - ndarray - (spatial_1, spatial_2, 2) Lat/Lon array with same ordering in last - dimension - """ - lat_lon = cls.lookup('lat_lon', 'compute')(file_paths, raster_index) - if invert_lat: - lat_lon = lat_lon[::-1] - # put angle betwen -180 and 180 - lat_lon[..., 1] = (lat_lon[..., 1] + 180) % 360 - 180 - return lat_lon - - @classmethod - def get_node_cmd(cls, config): - """Get a CLI call to initialize DataHandler and cache data. - - Parameters - ---------- - config : dict - sup3r data handler config with all necessary args and kwargs to - initialize DataHandler and run data extraction. 
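A quick numerical check (with made-up angles) of the longitude wrapping applied in get_lat_lon above, which maps longitudes into the [-180, 180) range:

import numpy as np

lons = np.array([0.0, 90.0, 180.0, 250.0, 359.0])
wrapped = (lons + 180) % 360 - 180               # same formula as above
print(wrapped)                                   # [   0.   90. -180. -110.   -1.]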
- """ - - import_str = ( - 'from sup3r.preprocessing.data_handling ' - f'import {cls.__name__};\n' - 'import time;\n' - 'from reV.pipeline.status import Status;\n' - 'from rex import init_logger;\n' - ) - - dh_init_str = get_fun_call_str(cls, config) - - log_file = config.get('log_file', None) - log_level = config.get('log_level', 'INFO') - log_arg_str = f'"sup3r", log_level="{log_level}"' - if log_file is not None: - log_arg_str += f', log_file="{log_file}"' - - cache_check = config.get('cache_pattern', False) - - msg = 'No cache file prefix provided.' - if not cache_check: - logger.warning(msg) - warnings.warn(msg) - - cmd = ( - f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"data_handler = {dh_init_str};\n" - "t_elap = time.time() - t0;\n" - ) - - cmd = BaseCLI.add_status_cmd(config, ModuleName.DATA_EXTRACT, cmd) - - cmd += ";\'\n" - return cmd.replace('\\', '/') - - def get_cache_file_names(self, cache_pattern): - """Get names of cache files from cache_pattern and feature names - - Parameters - ---------- - cache_pattern : str - Pattern to use for cache file names - - Returns - ------- - list - List of cache file names - """ - if cache_pattern is not None: - cache_files = [ - cache_pattern.replace('{feature}', f.lower()) - for f in self.features - ] - for i, f in enumerate(cache_files): - if '{shape}' in f: - shape = f'{self.grid_shape[0]}x{self.grid_shape[1]}' - shape += f'x{len(self.time_index)}' - f = f.replace('{shape}', shape) - if '{target}' in f: - target = f'{self.target[0]:.2f}_{self.target[1]:.2f}' - f = f.replace('{target}', target) - if '{times}' in f: - times = f'{self.timestamp_0}_{self.timestamp_1}' - f = f.replace('{times}', times) - - cache_files[i] = f - - for i, fp in enumerate(cache_files): - fp_check = ignore_case_path_fetch(fp) - if fp_check is not None: - cache_files[i] = fp_check - else: - cache_files = None - - return cache_files - - def unnormalize(self, means, stds): - """Remove normalization from stored means and stds""" - self.val_data = (self.val_data * stds) + means - self.data = (self.data * stds) + means - - def normalize(self, means, stds): - """Normalize all data features - - Parameters - ---------- - means : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - stds : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - """ - logger.info(f'Normalizing {self.shape[-1]} features.') - max_workers = self.norm_workers - if max_workers == 1: - for i in range(self.shape[-1]): - self._normalize_data(i, means[i], stds[i]) - else: - self.parallel_normalization(means, stds, max_workers=max_workers) - - def parallel_normalization(self, means, stds, max_workers=None): - """Run normalization of features in parallel - - Parameters - ---------- - means : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - stds : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - max_workers : int | None - Max number of workers to use for normalizing features - """ - - with ThreadPoolExecutor(max_workers=max_workers) as exe: - futures = {} - now = dt.now() - for i in range(self.shape[-1]): - future = exe.submit(self._normalize_data, i, means[i], stds[i]) - futures[future] = i - - logger.info( - f'Started normalizing {self.shape[-1]} features ' - f'in {dt.now() - now}.' 
- ) - - for i, future in enumerate(as_completed(futures)): - try: - future.result() - except Exception as e: - msg = ( - 'Error while normalizing future number ' - f'{futures[future]}.' - ) - logger.exception(msg) - raise RuntimeError(msg) from e - logger.debug( - f'{i+1} out of {self.shape[-1]} features ' 'normalized.' - ) - - def _normalize_data(self, feature_index, mean, std): - """Normalize data with initialized mean and standard deviation for a - specific feature - - Parameters - ---------- - feature_index : int - index of feature to be normalized - mean : float32 - specified mean of associated feature - std : float32 - specificed standard deviation for associated feature - """ - - if self.val_data is not None: - self.val_data[..., feature_index] -= mean - self.data[..., feature_index] -= mean - - if std > 0: - if self.val_data is not None: - self.val_data[..., feature_index] /= std - self.data[..., feature_index] /= std - else: - msg = ( - 'Standard Deviation is zero for ' - f'{self.features[feature_index]}' - ) - logger.warning(msg) - warnings.warn(msg) - - def get_observation_index(self): - """Randomly gets spatial sample and time sample - - Returns - ------- - observation_index : tuple - Tuple of sampled spatial grid, time slice, and features indices. - Used to get single observation like self.data[observation_index] - """ - spatial_slice = uniform_box_sampler(self.data, self.sample_shape[:2]) - temporal_slice = uniform_time_sampler(self.data, self.sample_shape[2]) - return tuple( - [*spatial_slice, temporal_slice, np.arange(len(self.features))] - ) - - def get_next(self): - """Get data for observation using random observation index. Loops - repeatedly over randomized time index - - Returns - ------- - observation : np.ndarray - 4D array - (spatial_1, spatial_2, temporal, features) - """ - self.current_obs_index = self.get_observation_index() - observation = self.data[self.current_obs_index] - return observation - - def split_data(self, data=None): - """Split time dimension into set of training indices and validation - indices - - Parameters - ---------- - data : np.ndarray - 4D array of high res data - (spatial_1, spatial_2, temporal, features) - - Returns - ------- - data : np.ndarray - (spatial_1, spatial_2, temporal, features) - Training data fraction of initial data array. Initial data array is - overwritten by this new data array. - val_data : np.ndarray - (spatial_1, spatial_2, temporal, features) - Validation data fraction of initial data array. 
- """ - - if data is not None: - self.data = data - - n_observations = self.data.shape[2] - all_indices = np.arange(n_observations) - n_val_obs = int(self.val_split * n_observations) - - if self.shuffle_time: - np.random.shuffle(all_indices) - - val_indices = all_indices[:n_val_obs] - training_indices = all_indices[n_val_obs:] - - if not self.shuffle_time: - [self.val_data, self.data] = np.split( - self.data, [n_val_obs], axis=2 - ) - else: - self.val_data = self.data[:, :, val_indices, :] - self.data = self.data[:, :, training_indices, :] - - self.val_time_index = self.time_index[val_indices] - self.time_index = self.time_index[training_indices] - - return self.data, self.val_data - - @property - def shape(self): - """Full data shape - - Returns - ------- - shape : tuple - Full data shape - (spatial_1, spatial_2, temporal, features) - """ - return self.data.shape - - def cache_data(self, cache_file_paths): - """Cache feature data to file and delete from memory - - Parameters - ---------- - cache_file_paths : str | None - Path to file for saving feature data - """ - - for i, fp in enumerate(cache_file_paths): - if not os.path.exists(fp) or self.overwrite_cache: - if self.overwrite_cache and os.path.exists(fp): - logger.info( - f'Overwriting {self.features[i]} with shape ' - f'{self.data[..., i].shape} to {fp}' - ) - else: - logger.info( - f'Saving {self.features[i]} with shape ' - f'{self.data[..., i].shape} to {fp}' - ) - - tmp_file = fp.replace('.pkl', '.pkl.tmp') - with open(tmp_file, 'wb') as fh: - pickle.dump(self.data[..., i], fh, protocol=4) - os.replace(tmp_file, fp) - else: - msg = ( - f'Called cache_data but {fp} already exists. Set to ' - 'overwrite_cache to True to overwrite.' - ) - logger.warning(msg) - warnings.warn(msg) - - def parallel_load(self, max_workers=None): - """Load feature data in parallel - - Parameters - ---------- - max_workers : int | None - Max number of workers to use for parallel data loading. If None - the max number of available workers will be used. - """ - logger.info(f'Loading {len(self.cache_files)} cache files.') - with ThreadPoolExecutor(max_workers=max_workers) as exe: - futures = {} - now = dt.now() - for i, fp in enumerate(self.cache_files): - future = exe.submit(self.load_single_cached_feature, fp=fp) - futures[future] = {'idx': i, 'fp': os.path.basename(fp)} - - logger.info( - f'Started loading all {len(self.cache_files)} cache ' - f'files in {dt.now() - now}.' - ) - - for i, future in enumerate(as_completed(futures)): - try: - future.result() - except Exception as e: - msg = ( - 'Error while loading ' - f'{self.cache_files[futures[future]["idx"]]}' - ) - logger.exception(msg) - raise RuntimeError(msg) from e - logger.debug( - f'{i+1} out of {len(futures)} cache files ' - f'loaded: {futures[future]["fp"]}' - ) - - def load_single_cached_feature(self, fp): - """Load single feature from given file - - Parameters - ---------- - fp : string - File path for feature cache file - - Raises - ------ - RuntimeError - Error raised if shape conflicts with requested shape - """ - idx = self.cache_files.index(fp) - assert self.features[idx].lower() in fp.lower() - fp = ignore_case_path_fetch(fp) - logger.info( - f'Loading {self.features[idx]} from ' f'{os.path.basename(fp)}' - ) - - with open(fp, 'rb') as fh: - try: - self.data[..., idx] = np.array( - pickle.load(fh), dtype=np.float32 - ) - except Exception as e: - msg = ( - 'Data loaded from from cache file "{}" ' - 'could not be written to feature channel {} ' - 'of full data array of shape {}. 
' - 'The cached data has the wrong shape {}.'.format( - fp, idx, self.data.shape, pickle.load(fh).shape - ) - ) - raise RuntimeError(msg) from e - - def load_cached_data(self): - """Load data from cache files and split into training and validation""" - if self.data is not None: - logger.info('Called load_cached_data() but self.data is not None') - - elif self.data is None: - shape = get_raster_shape(self.raster_index) - requested_shape = ( - shape[0] // self.hr_spatial_coarsen, - shape[1] // self.hr_spatial_coarsen, - len(self.time_index), - len(self.features), - ) - - msg = ( - 'Found {} cache files but need {} for features {}! ' - 'These are the cache files that were found: {}'.format( - len(self.cache_files), - len(self.features), - self.features, - self.cache_files, - ) - ) - assert len(self.cache_files) == len(self.features), msg - - self.data = np.full( - shape=requested_shape, fill_value=np.nan, dtype=np.float32 - ) - - logger.info(f'Loading cached data from: {self.cache_files}') - max_workers = self.load_workers - if max_workers == 1: - for _, fp in enumerate(self.cache_files): - self.load_single_cached_feature(fp) - else: - self.parallel_load(max_workers=max_workers) - - nan_perc = 100 * np.isnan(self.data).sum() / self.data.size - if nan_perc > 0: - msg = 'Data has {:.2f}% NaN values!'.format(nan_perc) - logger.warning(msg) - warnings.warn(msg) - - logger.debug( - 'Splitting data into training / validation sets ' - f'({1 - self.val_split}, {self.val_split}) ' - f'for {self.input_file_info}' - ) - self.data, self.val_data = self.split_data() - - @classmethod - def check_cached_features( - cls, - features, - cache_files=None, - overwrite_cache=False, - load_cached=False, - ): - """Check which features have been cached and check flags to determine - whether to load or extract this features again - - Parameters - ---------- - features : list - list of features to extract - cache_files : list | None - Path to files with saved feature data - overwrite_cache : bool - Whether to overwrite cached files - load_cached : bool - Whether to load data from cache files - - Returns - ------- - list - List of features to extract. Might not include features which have - cache files. - """ - extract_features = [] - # check if any features can be loaded from cache - if cache_files is not None: - for i, f in enumerate(features): - check = ( - os.path.exists(cache_files[i]) - and f.lower() in cache_files[i].lower() - ) - if check: - if not overwrite_cache: - if load_cached: - msg = ( - f'{f} found in cache file {cache_files[i]}.' - ' Loading from cache instead of extracting ' - 'from source files' - ) - logger.info(msg) - else: - msg = ( - f'{f} found in cache file {cache_files[i]}.' - ' Call load_cached_data() or use ' - 'load_cached=True to load this data.' - ) - logger.info(msg) - else: - msg = ( - f'{cache_files[i]} exists but overwrite_cache ' - 'is set to True. Proceeding with extraction.' - ) - logger.info(msg) - extract_features.append(f) - else: - extract_features.append(f) - else: - extract_features = features - - return extract_features - - def run_all_data_init(self): - """Build base 4D data array. 
Can handle multiple files but assumes - each file has the same spatial domain - - Returns - ------- - data : np.ndarray - 4D array of high res data - (spatial_1, spatial_2, temporal, features) - """ - now = dt.now() - logger.debug(f'Loading data for raster of shape {self.grid_shape}') - - # get the file-native time index without pruning - if self.is_time_independent: - n_steps = 1 - shifted_time_chunks = [slice(None)] - else: - n_steps = len(self.raw_time_index[self.temporal_slice]) - shifted_time_chunks = get_chunk_slices( - n_steps, self.time_chunk_size - ) - - self.run_data_extraction() - self.run_data_compute() - - logger.info('Building final data array') - self.parallel_data_fill(shifted_time_chunks, self.extract_workers) - - if self.invert_lat: - self.data = self.data[::-1] - - if self.time_roll != 0: - logger.debug('Applying time roll to data array') - self.data = np.roll(self.data, self.time_roll, axis=2) - - if self.hr_spatial_coarsen > 1: - logger.debug('Applying hr spatial coarsening to data array') - self.data = spatial_coarsening( - self.data, s_enhance=self.hr_spatial_coarsen, obs_axis=False - ) - if self.load_cached: - for f in self.cached_features: - f_index = self.features.index(f) - logger.info(f'Loading {f} from {self.cache_files[f_index]}') - with open(self.cache_files[f_index], 'rb') as fh: - self.data[..., f_index] = pickle.load(fh) - - logger.info( - 'Finished extracting data for ' - f'{self.input_file_info} in ' - f'{dt.now() - now}' - ) - return self.data - - def run_data_extraction(self): - """Run the raw dataset extraction process from disk to raw - un-manipulated datasets. - """ - if self.extract_features: - logger.info( - f'Starting extraction of {self.extract_features} ' - f'using {len(self.time_chunks)} time_chunks.' - ) - if self.extract_workers == 1: - self._raw_data = self.serial_extract( - self.file_paths, - self.raster_index, - self.time_chunks, - self.extract_features, - **self.res_kwargs, - ) - - else: - self._raw_data = self.parallel_extract( - self.file_paths, - self.raster_index, - self.time_chunks, - self.extract_features, - self.extract_workers, - **self.res_kwargs, - ) - - logger.info( - f'Finished extracting {self.extract_features} for ' - f'{self.input_file_info}' - ) - - def run_data_compute(self): - """Run the data computation / derivation from raw features to desired - features. 
- """ - if self.derive_features: - logger.info(f'Starting computation of {self.derive_features}') - - if self.compute_workers == 1: - self._raw_data = self.serial_compute( - self._raw_data, - self.file_paths, - self.raster_index, - self.time_chunks, - self.derive_features, - self.noncached_features, - self.handle_features, - ) - - elif self.compute_workers != 1: - self._raw_data = self.parallel_compute( - self._raw_data, - self.file_paths, - self.raster_index, - self.time_chunks, - self.derive_features, - self.noncached_features, - self.handle_features, - self.compute_workers, - ) - - logger.info( - f'Finished computing {self.derive_features} for ' - f'{self.input_file_info}' - ) - - def data_fill(self, t, t_slice, f_index, f): - """Place single extracted / computed chunk in final data array - - Parameters - ---------- - t : int - Index of time slice in extracted / computed raw data dictionary - t_slice : slice - Time slice corresponding to the location in the final data array - f_index : int - Index of feature in the final data array - f : str - Name of corresponding feature in the raw data dictionary - """ - tmp = self._raw_data[t][f] - if len(tmp.shape) == 2: - tmp = tmp[..., np.newaxis] - self.data[..., t_slice, f_index] = tmp - - def serial_data_fill(self, shifted_time_chunks): - """Fill final data array in serial - - Parameters - ---------- - shifted_time_chunks : list - List of time slices corresponding to the appropriate location of - extracted / computed chunks in the final data array - """ - for t, ts in enumerate(shifted_time_chunks): - for _, f in enumerate(self.noncached_features): - f_index = self.features.index(f) - self.data_fill(t, ts, f_index, f) - interval = int(np.ceil(len(shifted_time_chunks) / 10)) - if t % interval == 0: - logger.info( - f'Added {t + 1} of {len(shifted_time_chunks)} ' - 'chunks to final data array' - ) - self._raw_data.pop(t) - - def parallel_data_fill(self, shifted_time_chunks, max_workers=None): - """Fill final data array with extracted / computed chunks - - Parameters - ---------- - shifted_time_chunks : list - List of time slices corresponding to the appropriate location of - extracted / computed chunks in the final data array - max_workers : int | None - Max number of workers to use for building final data array. If None - max available workers will be used. If 1 cached data will be loaded - in serial - """ - self.data = np.zeros( - ( - self.grid_shape[0], - self.grid_shape[1], - self.n_tsteps, - len(self.features), - ), - dtype=np.float32, - ) - - if max_workers == 1: - self.serial_data_fill(shifted_time_chunks) - else: - with ThreadPoolExecutor(max_workers=max_workers) as exe: - futures = {} - now = dt.now() - for t, ts in enumerate(shifted_time_chunks): - for _, f in enumerate(self.noncached_features): - f_index = self.features.index(f) - future = exe.submit(self.data_fill, t, ts, f_index, f) - futures[future] = {'t': t, 'fidx': f_index} - - logger.info( - f'Started adding {len(futures)} chunks ' - f'to data array in {dt.now() - now}.' - ) - - interval = int(np.ceil(len(futures) / 10)) - for i, future in enumerate(as_completed(futures)): - try: - future.result() - except Exception as e: - msg = ( - f'Error adding ({futures[future]["t"]}, ' - f'{futures[future]["fidx"]}) chunk to ' - 'final data array.' 
- ) - logger.exception(msg) - raise RuntimeError(msg) from e - if i % interval == 0: - logger.debug( - f'Added {i+1} out of {len(futures)} ' - 'chunks to final data array' - ) - logger.info('Finished building data array') - - @abstractmethod - def get_raster_index(self): - """Get raster index for file data. Here we assume the list of paths in - file_paths all have data with the same spatial domain. We use the first - file in the list to compute the raster - - Returns - ------- - raster_index : np.ndarray - 2D array of grid indices for H5 or list of - slices for NETCDF - """ - - def lin_bc(self, bc_files, threshold=0.1): - """Bias correct the data in this DataHandler using linear bias - correction factors from files output by MonthlyLinearCorrection or - LinearCorrection from sup3r.bias.bias_calc - - Parameters - ---------- - bc_files : list | tuple | str - One or more filepaths to .h5 files output by - MonthlyLinearCorrection or LinearCorrection. These should contain - datasets named "{feature}_scalar" and "{feature}_adder" where - {feature} is one of the features contained by this DataHandler and - the data is a 3D array of shape (lat, lon, time) where time is - length 1 for annual correction or 12 for monthly correction. - threshold : float - Nearest neighbor euclidean distance threshold. If the DataHandler - coordinates are more than this value away from the bias correction - lat/lon, an error is raised. - """ - - if isinstance(bc_files, str): - bc_files = [bc_files] - - completed = [] - for idf, feature in enumerate(self.features): - for fp in bc_files: - if feature not in completed: - scalar, adder = get_spatial_bc_factors( - lat_lon=self.lat_lon, - feature_name=feature, - bias_fp=fp, - threshold=threshold, - ) - - if scalar.shape[-1] == 1: - scalar = np.repeat(scalar, self.shape[2], axis=2) - adder = np.repeat(adder, self.shape[2], axis=2) - elif scalar.shape[-1] == 12: - idm = self.time_index.month.values - 1 - scalar = scalar[..., idm] - adder = adder[..., idm] - else: - msg = ( - 'Can only accept bias correction factors ' - 'with last dim equal to 1 or 12 but ' - 'received bias correction factors with ' - 'shape {}'.format(scalar.shape) - ) - logger.error(msg) - raise RuntimeError(msg) - - logger.info( - 'Bias correcting "{}" with linear ' - 'correction from "{}"'.format( - feature, os.path.basename(fp) - ) - ) - self.data[..., idf] *= scalar - self.data[..., idf] += adder - completed.append(feature) - - -class DataHandlerNC(DataHandler): - """Data Handler for NETCDF data""" - - CHUNKS: ClassVar[dict] = { - 'XTIME': 100, - 'XLAT': 150, - 'XLON': 150, - 'south_north': 150, - 'west_east': 150, - 'Time': 100, - } - """CHUNKS sets the chunk sizes to extract from the data in each dimension. - Chunk sizes that approximately match the data volume being extracted - typically results in the most efficient IO.""" - - def __init__(self, *args, xr_chunks=None, **kwargs): - """Initialize NETCDF data handler. - - Parameters - ---------- - *args : list - Same ordered required arguments as DataHandler parent class. - xr_chunks : int | "auto" | tuple | dict | None - kwarg that goes to xr.DataArray.chunk(chunks=xr_chunks). Chunk - sizes that approximately match the data volume being extracted - typically results in the most efficient IO. If not provided, this - defaults to the class CHUNKS attribute. - **kwargs : list - Same optional keyword arguments as DataHandler parent class. 
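A rough usage sketch of the xr_chunks override described above (the file glob, feature names, target, and chunk sizes here are hypothetical, not taken from this repository); xr_chunks replaces the class-level CHUNKS that are passed to xarray when the source files are opened:

from sup3r.preprocessing.data_handling import DataHandlerNC

handler = DataHandlerNC(
    ['./wrf_out_2015_*.nc'],       # hypothetical WRF output files
    ['U_100m', 'V_100m'],          # features to extract
    target=(39.0, -105.0),         # (lat, lon) lower left corner
    shape=(20, 20),                # (rows, cols) raster size
    xr_chunks={'Time': 100, 'south_north': 150, 'west_east': 150},
)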
- """ - if xr_chunks is not None: - self.CHUNKS = xr_chunks - - super().__init__(*args, **kwargs) - - @property - def extract_workers(self): - """Get upper bound for extract workers based on memory limits. Used to - extract data from source dataset""" - # This large multiplier is due to the height interpolation allocating - # multiple arrays with up to 60 vertical levels - proc_mem = 6 * 64 * self.grid_mem * len(self.time_index) - proc_mem /= len(self.time_chunks) - n_procs = len(self.time_chunks) * len(self.extract_features) - n_procs = int(np.ceil(n_procs)) - extract_workers = estimate_max_workers( - self._extract_workers, proc_mem, n_procs - ) - return extract_workers - - @classmethod - def source_handler(cls, file_paths, **kwargs): - """Xarray data handler - - Note that xarray appears to treat open file handlers as singletons - within a threadpool, so its okay to open this source_handler without a - context handler or a .close() statement. - - Parameters - ---------- - file_paths : str | list - paths to data files - kwargs : dict - kwargs passed to source handler for data extraction. e.g. This - could be {'parallel': True, - 'chunks': {'south_north': 120, 'west_east': 120}} - which then gets passed to xr.open_mfdataset(file, **kwargs) - - Returns - ------- - data : xarray.Dataset - """ - time_key = get_time_dim_name(file_paths[0]) - default_kws = { - 'combine': 'nested', - 'concat_dim': time_key, - 'chunks': cls.CHUNKS, - } - default_kws.update(kwargs) - return xr.open_mfdataset(file_paths, **default_kws) - - @classmethod - def get_file_times(cls, file_paths, **kwargs): - """Get time index from data files - - Parameters - ---------- - file_paths : list - path to data file - kwargs : dict - kwargs passed to source handler for data extraction. e.g. This - could be {'parallel': True, - 'chunks': {'south_north': 120, 'west_east': 120}} - which then gets passed to xr.open_mfdataset(file, **kwargs) - - Returns - ------- - time_index : pd.Datetimeindex - List of times as a Datetimeindex - """ - handle = cls.source_handler(file_paths, **kwargs) - - if hasattr(handle, 'Times'): - time_index = np_to_pd_times(handle.Times.values) - elif hasattr(handle, 'indexes') and 'time' in handle.indexes: - time_index = handle.indexes['time'] - if not isinstance(time_index, pd.DatetimeIndex): - time_index = time_index.to_datetimeindex() - elif hasattr(handle, 'times'): - time_index = np_to_pd_times(handle.times.values) - else: - msg = ( - f'Could not get time_index for {file_paths}. ' - 'Assuming time independence.' - ) - time_index = None - logger.warning(msg) - warnings.warn(msg) - - return time_index - - @classmethod - def get_time_index(cls, file_paths, max_workers=None, **kwargs): - """Get time index from data files - - Parameters - ---------- - file_paths : list - path to data file - max_workers : int | None - Max number of workers to use for parallel time index building - kwargs : dict - kwargs passed to source handler for data extraction. e.g. 
This - could be {'parallel': True, - 'chunks': {'south_north': 120, 'west_east': 120}} - which then gets passed to xr.open_mfdataset(file, **kwargs) - - Returns - ------- - time_index : pd.Datetimeindex - List of times as a Datetimeindex - """ - max_workers = ( - len(file_paths) - if max_workers is None - else np.min((max_workers, len(file_paths))) - ) - if max_workers == 1: - return cls.get_file_times(file_paths, **kwargs) - ti = {} - with ThreadPoolExecutor(max_workers=max_workers) as exe: - futures = {} - now = dt.now() - for i, f in enumerate(file_paths): - future = exe.submit(cls.get_file_times, [f], **kwargs) - futures[future] = {'idx': i, 'file': f} - - logger.info( - f'Started building time index from {len(file_paths)} ' - f'files in {dt.now() - now}.' - ) - - for i, future in enumerate(as_completed(futures)): - try: - val = future.result() - if val is not None: - ti[futures[future]['idx']] = list(val) - except Exception as e: - msg = ( - 'Error while getting time index from file ' - f'{futures[future]["file"]}.' - ) - logger.exception(msg) - raise RuntimeError(msg) from e - logger.debug(f'Stored {i+1} out of {len(futures)} file times') - times = np.concatenate(list(ti.values())) - return pd.DatetimeIndex(sorted(set(times))) - - @classmethod - def feature_registry(cls): - """Registry of methods for computing features - - Returns - ------- - dict - Method registry - """ - registry = { - 'BVF2_(.*)m': BVFreqSquaredNC, - 'BVF_MO_(.*)m': BVFreqMon, - 'RMOL': InverseMonNC, - 'U_(.*)': UWind, - 'V_(.*)': VWind, - 'Windspeed_(.*)m': WindspeedNC, - 'Winddirection_(.*)m': WinddirectionNC, - 'lat_lon': LatLonNC, - 'Shear_(.*)m': Shear, - 'REWS_(.*)m': Rews, - 'Temperature_(.*)m': TempNC, - 'Pressure_(.*)m': PressureNC, - 'PotentialTemp_(.*)m': PotentialTempNC, - 'PT_(.*)m': PotentialTempNC, - 'topography': 'HGT', - } - return registry - - @classmethod - def extract_feature( - cls, - file_paths, - raster_index, - feature, - time_slice=slice(None), - **kwargs, - ): - """Extract single feature from data source. The requested feature - can match exactly to one found in the source data or can have a - matching prefix with a suffix specifying the height or pressure level - to interpolate to. e.g. feature=U_100m -> interpolate exact match U to - 100 meters. - - Parameters - ---------- - file_paths : list - path to data file - raster_index : ndarray - Raster index array - feature : str - Feature to extract from data - time_slice : slice - slice of time to extract - kwargs : dict - kwargs passed to source handler for data extraction. e.g. This - could be {'parallel': True, - 'chunks': {'south_north': 120, 'west_east': 120}} - which then gets passed to xr.open_mfdataset(file, **kwargs) - - Returns - ------- - ndarray - Data array for extracted feature - (spatial_1, spatial_2, temporal) - """ - logger.debug( - f'Extracting {feature} with time_slice={time_slice}, ' - f'raster_index={raster_index}, kwargs={kwargs}.' 
- ) - handle = cls.source_handler(file_paths, **kwargs) - f_info = Feature(feature, handle) - interp_height = f_info.height - interp_pressure = f_info.pressure - basename = f_info.basename - - if feature in handle or feature.lower() in handle: - feat_key = feature if feature in handle else feature.lower() - fdata = cls.direct_extract( - handle, feat_key, raster_index, time_slice - ) - - elif basename in handle or basename.lower() in handle: - feat_key = basename if basename in handle else basename.lower() - if interp_height is not None: - fdata = Interpolator.interp_var_to_height( - handle, - feat_key, - raster_index, - np.float32(interp_height), - time_slice, - ) - elif interp_pressure is not None: - fdata = Interpolator.interp_var_to_pressure( - handle, - feat_key, - raster_index, - np.float32(interp_pressure), - time_slice, - ) - - else: - msg = f'{feature} cannot be extracted from source data.' - logger.exception(msg) - raise ValueError(msg) - - fdata = np.transpose(fdata, (1, 2, 0)) - return fdata.astype(np.float32) - - @classmethod - def direct_extract(cls, handle, feature, raster_index, time_slice): - """Extract requested feature directly from source data, rather than - interpolating to a requested height or pressure level - - Parameters - ---------- - handle : xarray - netcdf data object - feature : str - Name of feature to extract directly from source handler - raster_index : list - List of slices for raster index of spatial domain - time_slice : slice - slice of time to extract - - Returns - ------- - fdata : ndarray - Data array for requested feature - """ - # Sometimes xarray returns fields with (Times, time, lats, lons) - # with a single entry in the 'time' dimension so we include this [0] - if len(handle[feature].dims) == 4: - idx = tuple([time_slice, 0, *raster_index]) - elif len(handle[feature].dims) == 3: - idx = tuple([time_slice, *raster_index]) - else: - idx = tuple(raster_index) - fdata = np.array(handle[feature][idx], dtype=np.float32) - if len(fdata.shape) == 2: - fdata = np.expand_dims(fdata, axis=0) - return fdata - - @classmethod - def get_full_domain(cls, file_paths): - """Get full shape and min available lat lon. To simplify processing - of full domain without needing to specify target and shape. 
- - Parameters - ---------- - file_paths : list - List of data file paths - - Returns - ------- - target : tuple - (lat, lon) for lower left corner - lat_lon : ndarray - Raw lat/lon array for entire domain - """ - return cls.get_lat_lon(file_paths, [slice(None), slice(None)]) - - @staticmethod - def get_closest_lat_lon(lat_lon, target): - """Get closest indices to target lat lon to use for lower left corner - of raster index - - Parameters - ---------- - lat_lon : ndarray - Array of lat/lon - (spatial_1, spatial_2, 2) - Last dimension in order of (lat, lon) - target : tuple - (lat, lon) for lower left corner - - Returns - ------- - row : int - row index for closest lat/lon to target lat/lon - col : int - col index for closest lat/lon to target lat/lon - """ - # shape of ll2 is (n, 2) where axis=1 is (lat, lon) - ll2 = np.vstack( - (lat_lon[..., 0].flatten(), lat_lon[..., 1].flatten()) - ).T - tree = KDTree(ll2) - _, i = tree.query(np.array(target)) - row, col = np.where( - (lat_lon[..., 0] == ll2[i, 0]) & (lat_lon[..., 1] == ll2[i, 1]) - ) - row = row[0] - col = col[0] - return row, col - - @classmethod - def compute_raster_index(cls, file_paths, target, grid_shape): - """Get raster index for a given target and shape - - Parameters - ---------- - file_paths : list - List of input data file paths - target : tuple - Target coordinate for lower left corner of extracted data - grid_shape : tuple - Shape out extracted data - - Returns - ------- - list - List of slices corresponding to extracted data region - """ - lat_lon = cls.get_lat_lon( - file_paths[:1], [slice(None), slice(None)], invert_lat=False - ) - cls._check_grid_extent(target, grid_shape, lat_lon) - - row, col = cls.get_closest_lat_lon(lat_lon, target) - - closest = tuple(lat_lon[row, col]) - logger.debug(f'Found closest coordinate {closest} to target={target}') - if np.hypot(closest[0] - target[0], closest[1] - target[1]) > 1: - msg = 'Closest coordinate to target is more than 1 degree away' - logger.warning(msg) - warnings.warn(msg) - - if cls.lats_are_descending(lat_lon): - row_end = row + 1 - row_start = row_end - grid_shape[0] - else: - row_end = row + grid_shape[0] - row_start = row - raster_index = [ - slice(row_start, row_end), - slice(col, col + grid_shape[1]), - ] - cls._validate_raster_shape(target, grid_shape, lat_lon, raster_index) - return raster_index - - @classmethod - def _check_grid_extent(cls, target, grid_shape, lat_lon): - """Make sure the requested target coordinate lies within the available - lat/lon grid. - - Parameters - ---------- - target : tuple - Target coordinate for lower left corner of extracted data - grid_shape : tuple - Shape out extracted data - lat_lon : ndarray - Array of lat/lon coordinates for entire available grid. Used to - check whether computed raster only includes coordinates within this - grid. 
- """ - min_lat = np.min(lat_lon[..., 0]) - min_lon = np.min(lat_lon[..., 1]) - max_lat = np.max(lat_lon[..., 0]) - max_lon = np.max(lat_lon[..., 1]) - logger.debug( - 'Calculating raster index from WRF file ' - f'for shape {grid_shape} and target {target}' - ) - logger.debug( - f'lat/lon (min, max): {min_lat}/{min_lon}, ' f'{max_lat}/{max_lon}' - ) - msg = ( - f'target {target} out of bounds with min lat/lon ' - f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}' - ) - assert ( - min_lat <= target[0] <= max_lat and min_lon <= target[1] <= max_lon - ), msg - - @classmethod - def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): - """Make sure the computed raster_index only includes coordinates within - the available grid - - Parameters - ---------- - target : tuple - Target coordinate for lower left corner of extracted data - grid_shape : tuple - Shape out extracted data - lat_lon : ndarray - Array of lat/lon coordinates for entire available grid. Used to - check whether computed raster only includes coordinates within this - grid. - raster_index : list - List of slices selecting region from entire available grid. - """ - if ( - raster_index[0].stop > lat_lon.shape[0] - or raster_index[1].stop > lat_lon.shape[1] - or raster_index[0].start < 0 - or raster_index[1].start < 0 - ): - msg = ( - f'Invalid target {target}, shape {grid_shape}, and raster ' - f'{raster_index} for data domain of size ' - f'{lat_lon.shape[:-1]} with lower left corner ' - f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' - f' and upper right corner ({np.max(lat_lon[..., 0])}, ' - f'{np.max(lat_lon[..., 1])}).' - ) - raise ValueError(msg) - - def get_raster_index(self): - """Get raster index for file data. Here we assume the list of paths in - file_paths all have data with the same spatial domain. We use the first - file in the list to compute the raster. - - Returns - ------- - raster_index : np.ndarray - 2D array of grid indices - """ - self.raster_file = ( - self.raster_file - if self.raster_file is None - else self.raster_file.replace('.txt', '.npy') - ) - if self.raster_file is not None and os.path.exists(self.raster_file): - logger.debug( - f'Loading raster index: {self.raster_file} ' - f'for {self.input_file_info}' - ) - raster_index = np.load(self.raster_file, allow_pickle=True) - raster_index = list(raster_index) - else: - check = self.grid_shape is not None and self.target is not None - msg = ( - 'Must provide raster file or shape + target to get ' - 'raster index' - ) - assert check, msg - raster_index = self.compute_raster_index( - self.file_paths, self.target, self.grid_shape - ) - logger.debug( - 'Found raster index with row, col slices: {}'.format( - raster_index - ) - ) - - if self.raster_file is not None: - basedir = os.path.dirname(self.raster_file) - if not os.path.exists(basedir): - os.makedirs(basedir) - logger.debug(f'Saving raster index: {self.raster_file}') - np.save(self.raster_file.replace('.txt', '.npy'), raster_index) - - return raster_index - - -class DataHandlerNCforCC(DataHandlerNC): - """Data Handler for NETCDF climate change data""" - - CHUNKS: ClassVar[dict] = {'time': 5, 'lat': 20, 'lon': 20} - """CHUNKS sets the chunk sizes to extract from the data in each dimension. 
- Chunk sizes that approximately match the data volume being extracted - typically results in the most efficient IO.""" - - def __init__( - self, - *args, - nsrdb_source_fp=None, - nsrdb_agg=1, - nsrdb_smoothing=0, - **kwargs, - ): - """Initialize NETCDF data handler for climate change data. - - Parameters - ---------- - *args : list - Same ordered required arguments as DataHandler parent class. - nsrdb_source_fp : str | None - Optional NSRDB source h5 file to retrieve clearsky_ghi from to - calculate CC clearsky_ratio along with rsds (ghi) from the CC - netcdf file. - nsrdb_agg : int - Optional number of NSRDB source pixels to aggregate clearsky_ghi - from to a single climate change netcdf pixel. This can be used if - the CC.nc data is at a much coarser resolution than the source - nsrdb data. - nsrdb_smoothing : float - Optional gaussian filter smoothing factor to smooth out - clearsky_ghi from high-resolution nsrdb source data. This is - typically done because spatially aggregated nsrdb data is still - usually rougher than CC irradiance data. - **kwargs : list - Same optional keyword arguments as DataHandler parent class. - """ - self._nsrdb_source_fp = nsrdb_source_fp - self._nsrdb_agg = nsrdb_agg - self._nsrdb_smoothing = nsrdb_smoothing - super().__init__(*args, **kwargs) - - @classmethod - def feature_registry(cls): - """Registry of methods for computing features or extracting renamed - features - - Returns - ------- - dict - Method registry - """ - registry = { - 'U_(.*)': 'ua_(.*)', - 'V_(.*)': 'va_(.*)', - 'Windspeed_(.*)m': WindspeedNC, - 'Winddirection_(.*)m': WinddirectionNC, - 'topography': 'orog', - 'relativehumidity_2m': 'hurs', - 'relativehumidity_min_2m': 'hursmin', - 'relativehumidity_max_2m': 'hursmax', - 'clearsky_ratio': ClearSkyRatioCC, - 'lat_lon': LatLonNC, - 'Pressure_(.*)': 'plev_(.*)', - 'Temperature_(.*)': TempNCforCC, - 'temperature_2m': Tas, - 'temperature_max_2m': TasMax, - 'temperature_min_2m': TasMin, - } - return registry - - @classmethod - def source_handler(cls, file_paths, **kwargs): - """Xarray data handler - - Note that xarray appears to treat open file handlers as singletons - within a threadpool, so its okay to open this source_handler without a - context handler or a .close() statement. - - Parameters - ---------- - file_paths : str | list - paths to data files - kwargs : dict - kwargs passed to source handler for data extraction. e.g. This - could be {'parallel': True, - 'chunks': {'south_north': 120, 'west_east': 120}} - which then gets passed to xr.open_mfdataset(file, **kwargs) - - Returns - ------- - data : xarray.Dataset - """ - default_kws = {'chunks': cls.CHUNKS} - default_kws.update(kwargs) - return xr.open_mfdataset(file_paths, **default_kws) - - def run_data_extraction(self): - """Run the raw dataset extraction process from disk to raw - un-manipulated datasets. - - Includes a special method to extract clearsky_ghi from a exogenous - NSRDB source h5 file (required to compute clearsky_ratio). 
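As a hedged sketch of the clearsky_ghi pathway described above (the file names and aggregation factor are hypothetical), the handler is pointed at an NSRDB h5 source so that clearsky_ratio can be derived from the CC irradiance data:

from sup3r.preprocessing.data_handling import DataHandlerNCforCC

handler = DataHandlerNCforCC(
    ['./rsds_day_cc_model.nc'],             # hypothetical CC netcdf with rsds (ghi)
    ['clearsky_ratio'],
    target=(39.0, -105.0),
    shape=(20, 20),
    nsrdb_source_fp='./nsrdb_clearsky.h5',  # hypothetical NSRDB source of clearsky_ghi
    nsrdb_agg=4,                            # aggregate 4 NSRDB pixels per CC pixel
)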
- """ - get_clearsky = False - if 'clearsky_ghi' in self.raw_features: - get_clearsky = True - self._raw_features.remove('clearsky_ghi') - - super().run_data_extraction() - - if get_clearsky: - cs_ghi = self.get_clearsky_ghi() - - # clearsky ghi is extracted at the proper starting time index so - # the time chunks should start at 0 - tc0 = self.time_chunks[0].start - cs_ghi_time_chunks = [ - slice(tc.start - tc0, tc.stop - tc0, tc.step) - for tc in self.time_chunks - ] - for it, tslice in enumerate(cs_ghi_time_chunks): - self._raw_data[it]['clearsky_ghi'] = cs_ghi[..., tslice] - - self._raw_features.append('clearsky_ghi') - - def get_clearsky_ghi(self): - """Get clearsky ghi from an exogenous NSRDB source h5 file at the - target CC meta data and time index. - - Returns - ------- - cs_ghi : np.ndarray - Clearsky ghi (W/m2) from the nsrdb_source_fp h5 source file. Data - shape is (lat, lon, time) where time is daily average values. - """ - - msg = ( - 'Need nsrdb_source_fp input arg as a valid filepath to ' - 'retrieve clearsky_ghi (maybe for clearsky_ratio) but ' - 'received: {}'.format(self._nsrdb_source_fp) - ) - assert self._nsrdb_source_fp is not None, msg - assert os.path.exists(self._nsrdb_source_fp), msg - - msg = ( - 'Can only handle source CC data in hourly frequency but ' - 'received daily frequency of {}hrs (should be 24) ' - 'with raw time index: {}'.format( - self.time_freq_hours, self.raw_time_index - ) - ) - assert self.time_freq_hours == 24.0, msg - - msg = ( - 'Can only handle source CC data with temporal_slice.step == 1 ' - 'but received: {}'.format(self.temporal_slice.step) - ) - assert (self.temporal_slice.step is None) | ( - self.temporal_slice.step == 1 - ), msg - - with Resource(self._nsrdb_source_fp) as res: - ti_nsrdb = res.time_index - meta_nsrdb = res.meta - - ti_deltas = ti_nsrdb - np.roll(ti_nsrdb, 1) - ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600 - time_freq = float(mode(ti_deltas_hours).mode) - t_start = self.temporal_slice.start or 0 - t_end_target = self.temporal_slice.stop or len(self.raw_time_index) - t_start = int(t_start * 24 * (1 / time_freq)) - t_end = int(t_end_target * 24 * (1 / time_freq)) - t_end = np.minimum(t_end, len(ti_nsrdb)) - t_slice = slice(t_start, t_end) - - # pylint: disable=E1136 - lat = self.lat_lon[:, :, 0].flatten() - lon = self.lat_lon[:, :, 1].flatten() - cc_meta = np.vstack((lat, lon)).T - - tree = KDTree(meta_nsrdb[['latitude', 'longitude']]) - _, i = tree.query(cc_meta, k=self._nsrdb_agg) - if len(i.shape) == 1: - i = np.expand_dims(i, axis=1) - - logger.info( - 'Extracting clearsky_ghi data from "{}" with time slice ' - '{} and {} locations with agg factor {}.'.format( - os.path.basename(self._nsrdb_source_fp), - t_slice, - i.shape[0], - i.shape[1], - ) - ) - - cs_shape = i.shape - with Resource(self._nsrdb_source_fp) as res: - cs_ghi = res['clearsky_ghi', t_slice, i.flatten()] - - cs_ghi = cs_ghi.reshape((len(cs_ghi), *cs_shape)) - cs_ghi = cs_ghi.mean(axis=-1) - - windows = np.array_split( - np.arange(len(cs_ghi)), len(cs_ghi) // (24 // time_freq) - ) - cs_ghi = [cs_ghi[window].mean(axis=0) for window in windows] - cs_ghi = np.vstack(cs_ghi) - cs_ghi = cs_ghi.reshape((len(cs_ghi), *tuple(self.grid_shape))) - cs_ghi = np.transpose(cs_ghi, axes=(1, 2, 0)) - - if self.invert_lat: - cs_ghi = cs_ghi[::-1] - - logger.info( - 'Smoothing nsrdb clearsky ghi with a factor of {}'.format( - self._nsrdb_smoothing - ) - ) - for iday in range(cs_ghi.shape[-1]): - cs_ghi[..., iday] = gaussian_filter( - cs_ghi[..., 
iday], self._nsrdb_smoothing, mode='nearest' - ) - - if cs_ghi.shape[-1] < t_end_target: - n = int(np.ceil(t_end_target / cs_ghi.shape[-1])) - cs_ghi = np.repeat(cs_ghi, n, axis=2) - cs_ghi = cs_ghi[..., :t_end_target] - - logger.info( - 'Reshaped clearsky_ghi data to final shape {} to ' - 'correspond with CC daily average data over source ' - 'temporal_slice {} with (lat, lon) grid shape of {}'.format( - cs_ghi.shape, self.temporal_slice, self.grid_shape - ) - ) - - return cs_ghi - - -class DataHandlerH5(DataHandler): - """DataHandler for H5 Data""" - - # the handler from rex to open h5 data. - REX_HANDLER = MultiFileWindX - - @classmethod - def source_handler(cls, file_paths, **kwargs): - """Rex data handler - - Note that xarray appears to treat open file handlers as singletons - within a threadpool, so its okay to open this source_handler without a - context handler or a .close() statement. - - Parameters - ---------- - file_paths : str | list - paths to data files - kwargs : dict - keyword arguments passed to source handler - - Returns - ------- - data : ResourceX - """ - return cls.REX_HANDLER(file_paths, **kwargs) - - @classmethod - def get_full_domain(cls, file_paths): - """Get target and shape for largest domain possible""" - msg = ( - 'You must either provide the target+shape inputs or an ' - 'existing raster_file input.' - ) - logger.error(msg) - raise ValueError(msg) - - @classmethod - def get_time_index(cls, file_paths, max_workers=None, **kwargs): - """Get time index from data files - - Parameters - ---------- - file_paths : list - path to data file - max_workers : int | None - placeholder to match signature - kwargs : dict - placeholder to match signature - - Returns - ------- - time_index : pd.DateTimeIndex - Time index from h5 source file(s) - """ - handle = cls.source_handler(file_paths) - time_index = handle.time_index - return time_index - - @classmethod - def feature_registry(cls): - """Registry of methods for computing features or extracting renamed - features - - Returns - ------- - dict - Method registry - """ - registry = { - 'BVF2_(.*)m': BVFreqSquaredH5, - 'BVF_MO_(.*)m': BVFreqMon, - 'U_(.*)m': UWind, - 'V_(.*)m': VWind, - 'lat_lon': LatLonH5, - 'REWS_(.*)m': Rews, - 'RMOL': 'inversemoninobukhovlength_2m', - 'P_(.*)m': 'pressure_(.*)m', - 'topography': TopoH5, - 'cloud_mask': CloudMaskH5, - 'clearsky_ratio': ClearSkyRatioH5, - } - return registry - - @classmethod - def extract_feature( - cls, - file_paths, - raster_index, - feature, - time_slice=slice(None), - **kwargs, - ): - """Extract single feature from data source - - Parameters - ---------- - file_paths : list - path to data file - raster_index : ndarray - Raster index array - feature : str - Feature to extract from data - time_slice : slice - slice of time to extract - kwargs : dict - keyword arguments passed to source handler - - Returns - ------- - ndarray - Data array for extracted feature - (spatial_1, spatial_2, temporal) - """ - logger.info(f'Extracting {feature} with kwargs={kwargs}') - handle = cls.source_handler(file_paths, **kwargs) - try: - fdata = handle[ - (feature, time_slice, *tuple([raster_index.flatten()])) - ] - except ValueError as e: - msg = f'{feature} cannot be extracted from source data' - logger.exception(msg) - raise ValueError(msg) from e - - fdata = fdata.reshape( - (-1, raster_index.shape[0], raster_index.shape[1]) - ) - fdata = np.transpose(fdata, (1, 2, 0)) - return fdata.astype(np.float32) - - def get_raster_index(self): - """Get raster index for file data. 
Here we assume the list of paths in - file_paths all have data with the same spatial domain. We use the first - file in the list to compute the raster. - - Returns - ------- - raster_index : np.ndarray - 2D array of grid indices - """ - if self.raster_file is not None and os.path.exists(self.raster_file): - logger.debug( - f'Loading raster index: {self.raster_file} ' - f'for {self.input_file_info}' - ) - raster_index = np.loadtxt(self.raster_file).astype(np.uint32) - else: - check = self.grid_shape is not None and self.target is not None - msg = ( - 'Must provide raster file or shape + target to get ' - 'raster index' - ) - assert check, msg - logger.debug( - 'Calculating raster index from WTK file ' - f'for shape {self.grid_shape} and target ' - f'{self.target}' - ) - handle = self.source_handler(self.file_paths[0]) - raster_index = handle.get_raster_index( - self.target, self.grid_shape, max_delta=self.max_delta - ) - if self.raster_file is not None: - basedir = os.path.dirname(self.raster_file) - if not os.path.exists(basedir): - os.makedirs(basedir) - logger.debug(f'Saving raster index: {self.raster_file}') - np.savetxt(self.raster_file, raster_index) - return raster_index - - -class DataHandlerH5WindCC(DataHandlerH5): - """Special data handling and batch sampling for h5 wtk or nsrdb data for - climate change applications""" - - # the handler from rex to open h5 data. - REX_HANDLER = MultiFileWindX - - # list of features / feature name patterns that are input to the generative - # model but are not part of the synthetic output and are not sent to the - # discriminator. These are case-insensitive and follow the Unix shell-style - # wildcard format. - TRAIN_ONLY_FEATURES = ( - 'temperature_max_*m', - 'temperature_min_*m', - 'relativehumidity_max_*m', - 'relativehumidity_min_*m', - ) - - def __init__(self, *args, **kwargs): - """ - Parameters - ---------- - *args : list - Same positional args as DataHandlerH5 - **kwargs : dict - Same keyword args as DataHandlerH5 - """ - sample_shape = kwargs.get('sample_shape', (10, 10, 24)) - t_shape = sample_shape[-1] - - if len(sample_shape) == 2: - logger.info( - 'Found 2D sample shape of {}. Adding spatial dim of 24'.format( - sample_shape - ) - ) - sample_shape = (*sample_shape, 24) - t_shape = sample_shape[-1] - kwargs['sample_shape'] = sample_shape - - if t_shape < 24 or t_shape % 24 != 0: - msg = ( - 'Climate Change DataHandler can only work with temporal ' - 'sample shapes that are one or more days of hourly data ' - '(e.g. 24, 48, 72...). The requested temporal sample ' - 'shape was: {}'.format(t_shape) - ) - logger.error(msg) - raise RuntimeError(msg) - - # validation splits not enabled for solar CC model. 
Here we assume the list of paths in
- file_paths all have data with the same spatial domain. We use the first
- file in the list to compute the raster.
-
- Returns
- -------
- raster_index : np.ndarray
- 2D array of grid indices
- """
- if self.raster_file is not None and os.path.exists(self.raster_file):
- logger.debug(
- f'Loading raster index: {self.raster_file} '
- f'for {self.input_file_info}'
- )
- raster_index = np.loadtxt(self.raster_file).astype(np.uint32)
- else:
- check = self.grid_shape is not None and self.target is not None
- msg = (
- 'Must provide raster file or shape + target to get '
- 'raster index'
- )
- assert check, msg
- logger.debug(
- 'Calculating raster index from WTK file '
- f'for shape {self.grid_shape} and target '
- f'{self.target}'
- )
- handle = self.source_handler(self.file_paths[0])
- raster_index = handle.get_raster_index(
- self.target, self.grid_shape, max_delta=self.max_delta
- )
- if self.raster_file is not None:
- basedir = os.path.dirname(self.raster_file)
- if not os.path.exists(basedir):
- os.makedirs(basedir)
- logger.debug(f'Saving raster index: {self.raster_file}')
- np.savetxt(self.raster_file, raster_index)
- return raster_index
-
-
-class DataHandlerH5WindCC(DataHandlerH5):
- """Special data handling and batch sampling for h5 wtk or nsrdb data for
- climate change applications"""
-
- # the handler from rex to open h5 data.
- REX_HANDLER = MultiFileWindX
-
- # list of features / feature name patterns that are input to the generative
- # model but are not part of the synthetic output and are not sent to the
- # discriminator. These are case-insensitive and follow the Unix shell-style
- # wildcard format.
- TRAIN_ONLY_FEATURES = (
- 'temperature_max_*m',
- 'temperature_min_*m',
- 'relativehumidity_max_*m',
- 'relativehumidity_min_*m',
- )
-
- def __init__(self, *args, **kwargs):
- """
- Parameters
- ----------
- *args : list
- Same positional args as DataHandlerH5
- **kwargs : dict
- Same keyword args as DataHandlerH5
- """
- sample_shape = kwargs.get('sample_shape', (10, 10, 24))
- t_shape = sample_shape[-1]
-
- if len(sample_shape) == 2:
- logger.info(
- 'Found 2D sample shape of {}. Adding temporal dim of 24'.format(
- sample_shape
- )
- )
- sample_shape = (*sample_shape, 24)
- t_shape = sample_shape[-1]
- kwargs['sample_shape'] = sample_shape
-
- if t_shape < 24 or t_shape % 24 != 0:
- msg = (
- 'Climate Change DataHandler can only work with temporal '
- 'sample shapes that are one or more days of hourly data '
- '(e.g. 24, 48, 72...). The requested temporal sample '
- 'shape was: {}'.format(t_shape)
- )
- logger.error(msg)
- raise RuntimeError(msg)
-
- # validation splits not enabled for solar CC model.
- kwargs['val_split'] = 0.0 - - super().__init__(*args, **kwargs) - - self.daily_data = None - self.daily_data_slices = None - self.run_daily_averages() - - def run_daily_averages(self): - """Calculate daily average data and store as attribute.""" - msg = ( - 'Data needs to be hourly with at least 24 hours, but data ' - 'shape is {}.'.format(self.data.shape) - ) - assert self.data.shape[2] % 24 == 0, msg - assert self.data.shape[2] > 24, msg - - n_data_days = int(self.data.shape[2] / 24) - daily_data_shape = ( - self.data.shape[0:2] + (n_data_days,) + (self.data.shape[3],) - ) - - logger.info( - 'Calculating daily average datasets for {} training ' - 'data days.'.format(n_data_days) - ) - - self.daily_data = np.zeros(daily_data_shape, dtype=np.float32) - - self.daily_data_slices = np.array_split( - np.arange(self.data.shape[2]), n_data_days - ) - self.daily_data_slices = [ - slice(x[0], x[-1] + 1) for x in self.daily_data_slices - ] - for idf, fname in enumerate(self.features): - for d, t_slice in enumerate(self.daily_data_slices): - if '_max_' in fname: - tmp = np.max(self.data[:, :, t_slice, idf], axis=2) - self.daily_data[:, :, d, idf] = tmp[:, :] - elif '_min_' in fname: - tmp = np.min(self.data[:, :, t_slice, idf], axis=2) - self.daily_data[:, :, d, idf] = tmp[:, :] - else: - tmp = daily_temporal_coarsening( - self.data[:, :, t_slice, idf], temporal_axis=2 - ) - self.daily_data[:, :, d, idf] = tmp[:, :, 0] - - logger.info( - 'Finished calculating daily average datasets for {} ' - 'training data days.'.format(n_data_days) - ) - - def _normalize_data(self, feature_index, mean, std): - """Normalize data with initialized mean and standard deviation for a - specific feature - - Parameters - ---------- - feature_index : int - index of feature to be normalized - mean : float32 - specified mean of associated feature - std : float32 - specificed standard deviation for associated feature - """ - super()._normalize_data(feature_index, mean, std) - self.daily_data[..., feature_index] -= mean - self.daily_data[..., feature_index] /= std - - @classmethod - def feature_registry(cls): - """Registry of methods for computing features - - Returns - ------- - dict - Method registry - """ - registry = { - 'U_(.*)m': UWind, - 'V_(.*)m': VWind, - 'lat_lon': LatLonH5, - 'topography': TopoH5, - 'temperature_max_(.*)m': 'temperature_(.*)m', - 'temperature_min_(.*)m': 'temperature_(.*)m', - 'relativehumidity_max_(.*)m': 'relativehumidity_(.*)m', - 'relativehumidity_min_(.*)m': 'relativehumidity_(.*)m', - } - return registry - - def get_observation_index(self): - """Randomly gets spatial sample and time sample - - Returns - ------- - obs_ind_hourly : tuple - Tuple of sampled spatial grid, time slice, and features indices. - Used to get single observation like self.data[observation_index]. - This is for hourly high-res data slicing. - obs_ind_daily : tuple - Same as obs_ind_hourly but the temporal index (i=2) is a slice of - the daily data (self.daily_data) with day integers. 
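A standalone sketch of how the hourly and daily indices pair up (illustrative numbers only, mirroring the slicing in the method body below): 72 hourly steps correspond to 3 daily-average entries, and both slices start on the same randomly chosen day.

import numpy as np

n_total_days = 10
daily_slices = [slice(d * 24, (d + 1) * 24) for d in range(n_total_days)]
sample_shape = (10, 10, 72)                        # 72 hours -> 3 days per observation
n_days = sample_shape[2] // 24
d0 = np.random.choice(len(daily_slices) - n_days)  # random starting day
hourly_slice = slice(daily_slices[d0].start, daily_slices[d0 + n_days - 1].stop)
daily_slice = slice(d0, d0 + n_days)               # matching daily-average indices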
- """ - spatial_slice = uniform_box_sampler(self.data, self.sample_shape[:2]) - - n_days = int(self.sample_shape[2] / 24) - rand_day_ind = np.random.choice(len(self.daily_data_slices) - n_days) - t_slice_0 = self.daily_data_slices[rand_day_ind] - t_slice_1 = self.daily_data_slices[rand_day_ind + n_days - 1] - t_slice_hourly = slice(t_slice_0.start, t_slice_1.stop) - t_slice_daily = slice(rand_day_ind, rand_day_ind + n_days) - - obs_ind_hourly = tuple( - [*spatial_slice, t_slice_hourly, np.arange(len(self.features))] - ) - - obs_ind_daily = tuple( - [*spatial_slice, t_slice_daily, np.arange(len(self.features))] - ) - - return obs_ind_hourly, obs_ind_daily - - def get_next(self): - """Get data for observation using random observation index. Loops - repeatedly over randomized time index - - Returns - ------- - obs_hourly : np.ndarray - 4D array - (spatial_1, spatial_2, temporal_hourly, features) - obs_daily_avg : np.ndarray - 4D array but the temporal axis is temporal_hourly//24 - (spatial_1, spatial_2, temporal_daily, features) - """ - obs_ind_hourly, obs_ind_daily = self.get_observation_index() - self.current_obs_index = obs_ind_hourly - obs_hourly = self.data[obs_ind_hourly] - obs_daily_avg = self.daily_data[obs_ind_daily] - return obs_hourly, obs_daily_avg - - def split_data(self, data=None): - """Split time dimension into set of training indices and validation - indices. For NSRDB it makes sure that the splits happen at midnight. - - Parameters - ---------- - data : np.ndarray - 4D array of high res data - (spatial_1, spatial_2, temporal, features) - - Returns - ------- - data : np.ndarray - (spatial_1, spatial_2, temporal, features) - Training data fraction of initial data array. Initial data array is - overwritten by this new data array. - val_data : np.ndarray - (spatial_1, spatial_2, temporal, features) - Validation data fraction of initial data array. - """ - - if data is not None: - self.data = data - - midnight_ilocs = np.where( - (self.time_index.hour == 0) - & (self.time_index.minute == 0) - & (self.time_index.second == 0) - )[0] - - n_val_obs = int(np.ceil(self.val_split * len(midnight_ilocs))) - val_split_index = midnight_ilocs[n_val_obs] - - self.val_data = self.data[:, :, slice(None, val_split_index), :] - self.data = self.data[:, :, slice(val_split_index, None), :] - - self.val_time_index = self.time_index[slice(None, val_split_index)] - self.time_index = self.time_index[slice(val_split_index, None)] - - return self.data, self.val_data - - -class DataHandlerH5SolarCC(DataHandlerH5WindCC): - """Special data handling and batch sampling for h5 NSRDB solar data for - climate change applications""" - - # the handler from rex to open h5 data. - REX_HANDLER = MultiFileNSRDBX - - # list of features / feature name patterns that are input to the generative - # model but are not part of the synthetic output and are not sent to the - # discriminator. These are case-insensitive and follow the Unix shell-style - # wildcard format. - TRAIN_ONLY_FEATURES = ('U*', 'V*', 'topography') - - def __init__(self, *args, **kwargs): - """ - Parameters - ---------- - *args : list - Same positional args as DataHandlerH5 - **kwargs : dict - Same keyword args as DataHandlerH5 - """ - - args = copy.deepcopy(args) # safe copy for manipulation - required = ['ghi', 'clearsky_ghi', 'clearsky_ratio'] - missing = [dset for dset in required if dset not in args[1]] - if any(missing): - msg = ( - 'Cannot initialize DataHandlerH5SolarCC without required ' - 'features {}. 
All three are necessary to get the daily ' - 'average clearsky ratio (ghi sum / clearsky ghi sum), even ' - 'though only the clearsky ratio will be passed to the ' - 'GAN.'.format(required) - ) - logger.error(msg) - raise KeyError(msg) - - super().__init__(*args, **kwargs) - - @classmethod - def feature_registry(cls): - """Registry of methods for computing features - - Returns - ------- - dict - Method registry - """ - registry = { - 'U': UWind, - 'V': VWind, - 'windspeed': 'wind_speed', - 'winddirection': 'wind_direction', - 'lat_lon': LatLonH5, - 'cloud_mask': CloudMaskH5, - 'clearsky_ratio': ClearSkyRatioH5, - 'topography': TopoH5, - } - return registry - - def run_daily_averages(self): - """Calculate daily average data and store as attribute. - - Note that the H5 clearsky ratio feature requires special logic to match - the climate change dataset of daily average GHI / daily average CS_GHI. - This target climate change dataset is not equivalent to the average of - instantaneous hourly clearsky ratios - """ - - msg = ( - 'Data needs to be hourly with at least 24 hours, but data ' - 'shape is {}.'.format(self.data.shape) - ) - assert self.data.shape[2] % 24 == 0, msg - assert self.data.shape[2] > 24, msg - - n_data_days = int(self.data.shape[2] / 24) - daily_data_shape = ( - self.data.shape[0:2] + (n_data_days,) + (self.data.shape[3],) - ) - - logger.info( - 'Calculating daily average datasets for {} training ' - 'data days.'.format(n_data_days) - ) - - self.daily_data = np.zeros(daily_data_shape, dtype=np.float32) - - self.daily_data_slices = np.array_split( - np.arange(self.data.shape[2]), n_data_days - ) - self.daily_data_slices = [ - slice(x[0], x[-1] + 1) for x in self.daily_data_slices - ] - - i_ghi = self.features.index('ghi') - i_cs = self.features.index('clearsky_ghi') - i_ratio = self.features.index('clearsky_ratio') - - for d, t_slice in enumerate(self.daily_data_slices): - for idf in range(self.data.shape[-1]): - self.daily_data[:, :, d, idf] = daily_temporal_coarsening( - self.data[:, :, t_slice, idf], temporal_axis=2 - )[:, :, 0] - - # note that this ratio of daily irradiance sums is not the same as - # the average of hourly ratios. - total_ghi = np.nansum(self.data[:, :, t_slice, i_ghi], axis=2) - total_cs_ghi = np.nansum(self.data[:, :, t_slice, i_cs], axis=2) - avg_cs_ratio = total_ghi / total_cs_ghi - self.daily_data[:, :, d, i_ratio] = avg_cs_ratio - - # remove ghi and clearsky ghi from feature set. These shouldn't be used - # downstream for solar cc and keeping them confuses the batch handler - logger.info( - 'Finished calculating daily average clearsky_ratio, ' - 'removing ghi and clearsky_ghi from the ' - 'DataHandlerH5SolarCC feature list.' 
- ) - ifeats = np.array( - [i for i in range(len(self.features)) if i not in (i_ghi, i_cs)] - ) - self.data = self.data[..., ifeats] - self.daily_data = self.daily_data[..., ifeats] - self.features.remove('ghi') - self.features.remove('clearsky_ghi') - - logger.info( - 'Finished calculating daily average datasets for {} ' - 'training data days.'.format(n_data_days) - ) - - -# pylint: disable=W0223 -class DataHandlerDC(DataHandler): - """Data-centric data handler""" - - def get_observation_index( - self, temporal_weights=None, spatial_weights=None - ): - """Randomly gets weighted spatial sample and time sample - - Parameters - ---------- - temporal_weights : array - Weights used to select time slice - (n_time_chunks) - spatial_weights : array - Weights used to select spatial chunks - (n_lat_chunks * n_lon_chunks) - - Returns - ------- - observation_index : tuple - Tuple of sampled spatial grid, time slice, and features indices. - Used to get single observation like self.data[observation_index] - """ - if spatial_weights is not None: - spatial_slice = weighted_box_sampler( - self.data, self.sample_shape[:2], weights=spatial_weights - ) - else: - spatial_slice = uniform_box_sampler( - self.data, self.sample_shape[:2] - ) - if temporal_weights is not None: - temporal_slice = weighted_time_sampler( - self.data, self.sample_shape[2], weights=temporal_weights - ) - else: - temporal_slice = uniform_time_sampler( - self.data, self.sample_shape[2] - ) - - return tuple( - [*spatial_slice, temporal_slice, np.arange(len(self.features))] - ) - - def get_next(self, temporal_weights=None, spatial_weights=None): - """Get data for observation using weighted random observation index. - Loops repeatedly over randomized time index. - - Parameters - ---------- - temporal_weights : array - Weights used to select time slice - (n_time_chunks) - spatial_weights : array - Weights used to select spatial chunks - (n_lat_chunks * n_lon_chunks) - - Returns - ------- - observation : np.ndarray - 4D array - (spatial_1, spatial_2, temporal, features) - """ - self.current_obs_index = self.get_observation_index( - temporal_weights=temporal_weights, spatial_weights=spatial_weights - ) - observation = self.data[self.current_obs_index] - return observation - - -class DataHandlerDCforNC(DataHandlerNC, DataHandlerDC): - """Data centric data handler for NETCDF files""" - - -class DataHandlerDCforH5(DataHandlerH5, DataHandlerDC): - """Data centric data handler for H5 files""" diff --git a/sup3r/preprocessing/data_handling/__init__.py b/sup3r/preprocessing/data_handling/__init__.py new file mode 100644 index 000000000..ba5802ea0 --- /dev/null +++ b/sup3r/preprocessing/data_handling/__init__.py @@ -0,0 +1,15 @@ +"""Collection of data handlers""" + +from .dual_data_handling import DualDataHandler +from .exogenous_data_handling import ExogenousDataHandler +from .h5_data_handling import ( + DataHandlerDCforH5, + DataHandlerH5, + DataHandlerH5SolarCC, + DataHandlerH5WindCC, +) +from .nc_data_handling import ( + DataHandlerDCforNC, + DataHandlerNC, + DataHandlerNCforCC, +) diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py new file mode 100644 index 000000000..bb5c15041 --- /dev/null +++ b/sup3r/preprocessing/data_handling/base.py @@ -0,0 +1,1727 @@ +"""Base data handling classes. 
+@author: bbenton +""" +import logging +import os +import pickle +import warnings +from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime as dt +from fnmatch import fnmatch + +import numpy as np +import pandas as pd +from rex.utilities import log_mem +from rex.utilities.fun_utils import get_fun_call_str +from scipy.spatial import KDTree + +from sup3r.bias.bias_transforms import get_spatial_bc_factors +from sup3r.preprocessing.data_handling.mixin import ( + InputMixIn, + TrainingPrepMixIn, +) +from sup3r.preprocessing.feature_handling import ( + BVFreqMon, + BVFreqSquaredNC, + Feature, + FeatureHandler, + InverseMonNC, + LatLonNC, + PotentialTempNC, + PressureNC, + Rews, + Shear, + TempNC, + UWind, + VWind, + WinddirectionNC, + WindspeedNC, +) +from sup3r.utilities import ModuleName +from sup3r.utilities.cli import BaseCLI +from sup3r.utilities.interpolation import Interpolator +from sup3r.utilities.utilities import ( + estimate_max_workers, + get_chunk_slices, + get_raster_shape, + np_to_pd_times, + spatial_coarsening, + uniform_box_sampler, + uniform_time_sampler, + weighted_box_sampler, + weighted_time_sampler, +) + +np.random.seed(42) + +logger = logging.getLogger(__name__) + + +class DataHandler(FeatureHandler, InputMixIn, TrainingPrepMixIn): + """Sup3r data handling and extraction for low-res source data or for + artificially coarsened high-res source data for training. + + The sup3r data handler class is based on a 4D numpy array of shape: + (spatial_1, spatial_2, temporal, features) + """ + + # list of features / feature name patterns that are input to the generative + # model but are not part of the synthetic output and are not sent to the + # discriminator. These are case-insensitive and follow the Unix shell-style + # wildcard format. + TRAIN_ONLY_FEATURES = ( + 'BVF*', + 'inversemoninobukhovlength_*', + 'RMOL', + 'topography', + ) + + def __init__(self, + file_paths, + features, + target=None, + shape=None, + max_delta=20, + temporal_slice=slice(None, None, 1), + hr_spatial_coarsen=None, + time_roll=0, + val_split=0.0, + sample_shape=(10, 10, 1), + raster_file=None, + raster_index=None, + shuffle_time=False, + time_chunk_size=None, + cache_pattern=None, + overwrite_cache=False, + overwrite_ti_cache=False, + load_cached=False, + train_only_features=None, + handle_features=None, + single_ts_files=None, + mask_nan=False, + worker_kwargs=None, + res_kwargs=None): + """ + Parameters + ---------- + file_paths : str | list + A single source h5 wind file to extract raster data from or a list + of netcdf files with identical grid. The string can be a unix-style + file path which will be passed through glob.glob + features : list + list of features to extract from the provided data + target : tuple + (lat, lon) lower left corner of raster. Either need target+shape or + raster_file. + shape : tuple + (rows, cols) grid size. Either need target+shape or raster_file. + max_delta : int, optional + Optional maximum limit on the raster shape that is retrieved at + once. If shape is (20, 20) and max_delta=10, the full raster will + be retrieved in four chunks of (10, 10). This helps adapt to + non-regular grids that curve over large distances, by default 20 + temporal_slice : slice + Slice specifying extent and step of temporal extraction. e.g. + slice(start, stop, time_pruning). If equal to slice(None, None, 1) + the full time dimension is selected. 
+ hr_spatial_coarsen : int | None + Optional input to coarsen the high-resolution spatial field. This + can be used if (for example) you have 2km source data, but you want + the final high res prediction target to be 4km resolution, then + hr_spatial_coarsen would be 2 so that the GAN is trained on + aggregated 4km high-res data. + time_roll : int + The number of places by which elements are shifted in the time + axis. Can be used to convert data to different timezones. This is + passed to np.roll(a, time_roll, axis=2) and happens AFTER the + temporal_slice operation. + val_split : float32 + Fraction of data to store for validation + sample_shape : tuple + Size of spatial and temporal domain used in a single high-res + observation for batching + raster_file : str | None + File for raster_index array for the corresponding target and shape. + If specified the raster_index will be loaded from the file if it + exists or written to the file if it does not yet exist. If None and + raster_index is not provided raster_index will be calculated + directly. Either need target+shape, raster_file, or raster_index + input. + raster_index : list + List of tuples or slices. Used as an alternative to computing the + raster index from target+shape or loading the raster index from + file + shuffle_time : bool + Whether to shuffle time indices before validation split + time_chunk_size : int + Size of chunks to split time dimension into for parallel data + extraction. If running in serial this can be set to the size of the + full time index for best performance. + cache_pattern : str | None + Pattern for files for saving feature data. e.g. + file_path_{feature}.pkl. Each feature will be saved to a file with + the feature name replaced in cache_pattern. If not None + feature arrays will be saved here and not stored in self.data until + load_cached_data is called. The cache_pattern can also include + {shape}, {target}, {times} which will help ensure unique cache + files for complex problems. + overwrite_cache : bool + Whether to overwrite any previously saved cache files. + overwrite_ti_cache : bool + Whether to overwrite any previously saved time index cache files. + load_cached : bool + Whether to load data from cache files + train_only_features : list | tuple | None + List of feature names or patt*erns that should only be included in + the training set and not the output. If None (default), this will + default to the class TRAIN_ONLY_FEATURES attribute. + handle_features : list | None + Optional list of features which are available in the provided data. + Providing this eliminates the need for an initial search of + available features prior to data extraction. + single_ts_files : bool | None + Whether input files are single time steps or not. If they are this + enables some reduced computation. If None then this will be + determined from file_paths directly. + mask_nan : bool + Flag to mask out (remove) any timesteps with NaN data from the + source dataset. This is False by default because it can create + discontinuities in the timeseries. + worker_kwargs : dict | None + Dictionary of worker values. Can include max_workers, + extract_workers, compute_workers, load_workers, norm_workers, + and ti_workers. Each argument needs to be an integer or None. + + The value of `max_workers` will set the value of all other worker + args. If max_workers == 1 then all processes will be serialized.
If + max_workers == None then other worker args will use their own + provided values. + + `extract_workers` is the max number of workers to use for + extracting features from source data. If None it will be estimated + based on memory limits. If 1 processes will be serialized. + `compute_workers` is the max number of workers to use for computing + derived features from raw features in source data. `load_workers` + is the max number of workers to use for loading cached feature + data. `norm_workers` is the max number of workers to use for + normalizing feature data. `ti_workers` is the max number of + workers to use to get full time index. Useful when there are many + input files each with a single time step. If this is greater than + one, time indices for input files will be extracted in parallel + and then concatenated to get the full time index. If input files + do not all have time indices or if there are few input files this + should be set to one. + res_kwargs : dict | None + kwargs passed to source handler for data extraction. e.g. This + could be {'parallel': True, + 'concat_dim': 'Time', + 'combine': 'nested', + 'chunks': {'south_north': 120, 'west_east': 120}} + which then gets passed to xr.open_mfdataset(file, **res_kwargs) + """ + InputMixIn.__init__(self, + target=target, + shape=shape, + raster_file=raster_file, + raster_index=raster_index, + temporal_slice=temporal_slice) + + self.file_paths = file_paths + self.features = ( + features if isinstance(features, (list, tuple)) else [features] + ) + self.val_time_index = None + self.max_delta = max_delta + self.val_split = val_split + self.sample_shape = sample_shape + self.hr_spatial_coarsen = hr_spatial_coarsen or 1 + self.time_roll = time_roll + self.shuffle_time = shuffle_time + self.current_obs_index = None + self.overwrite_cache = overwrite_cache + self.overwrite_ti_cache = overwrite_ti_cache + self.load_cached = load_cached + self.data = None + self.val_data = None + self.res_kwargs = res_kwargs or {} + self._single_ts_files = single_ts_files + self._cache_pattern = cache_pattern + self._train_only_features = train_only_features + self._time_chunk_size = time_chunk_size + self._handle_features = handle_features + self._cache_files = None + self._extract_features = None + self._noncached_features = None + self._raw_features = None + self._raw_data = {} + self._time_chunks = None + self.worker_kwargs = worker_kwargs or {} + self.max_workers = self.worker_kwargs.get('max_workers', None) + self._ti_workers = self.worker_kwargs.get('ti_workers', None) + self._extract_workers = self.worker_kwargs.get('extract_workers', None) + self._norm_workers = self.worker_kwargs.get('norm_workers', None) + self._load_workers = self.worker_kwargs.get('load_workers', None) + self._compute_workers = self.worker_kwargs.get('compute_workers', None) + self._worker_attrs = ['_ti_workers', + '_norm_workers', + '_compute_workers', + '_extract_workers', + '_load_workers'] + + self.preflight() + + overwrite = (self.overwrite_cache + and self.cache_files is not None + and all(os.path.exists(fp) for fp in self.cache_files)) + + if self.try_load and self.load_cached: + logger.info( + f'All {self.cache_files} exist. Loading from cache ' + f'instead of extracting from source files.') + self.load_cached_data() + + elif self.try_load and not self.load_cached: + self.clear_data() + logger.info( + f'All {self.cache_files} exist. 
Call ' + 'load_cached_data() or use load_cached=True to load ' + 'this data from cache files.') + else: + if overwrite: + logger.info( + f'{self.cache_files} exists but overwrite_cache ' + 'is set to True. Proceeding with extraction.') + + self._raster_size_check() + self._run_data_init_if_needed() + + if self._cache_pattern is not None: + self.cache_data(self.cache_files) + self.data = None if not self.load_cached else self.data + + self._val_split_check() + + if mask_nan and self.data is not None: + nan_mask = np.isnan(self.data).any(axis=(0, 1, 3)) + logger.info( + 'Removing {} out of {} timesteps due to NaNs'.format( + nan_mask.sum(), self.data.shape[2] + ) + ) + self.data = self.data[:, :, ~nan_mask, :] + + logger.info('Finished initializing DataHandler.') + log_mem(logger, log_level='INFO') + + @property + def try_load(self): + """Check if we should try to load cache""" + return self._should_load_cache( + self._cache_pattern, self.cache_files, self.overwrite_cache) + + def check_clear_data(self): + """Check if data is cached and clear data if not load_cached""" + if self._cache_pattern is not None and not self.load_cached: + self.data = None + self.val_data = None + + def _run_data_init_if_needed(self): + """Check if any features need to be extracted and proceed with data + extraction""" + if any(self.features): + self.data = self.run_all_data_init() + nan_perc = 100 * np.isnan(self.data).sum() / self.data.size + if nan_perc > 0: + msg = 'Data has {:.2f}% NaN values!'.format(nan_perc) + logger.warning(msg) + warnings.warn(msg) + + def _raster_size_check(self): + """Check if the sample_shape is larger than the requested raster + size""" + bad_shape = ( + self.sample_shape[0] > self.grid_shape[0] + and self.sample_shape[1] > self.grid_shape[1]) + if bad_shape: + msg = ( + f'spatial_sample_shape {self.sample_shape[:2]} is ' + f'larger than the raster size {self.grid_shape}') + logger.warning(msg) + warnings.warn(msg) + + def _val_split_check(self): + """Check if val_split > 0 and split data into validation and training. + Make sure validation data is larger than sample_shape""" + + if self.data is not None and self.val_split > 0.0: + self.data, self.val_data = self.split_data( + val_split=self.val_split, shuffle_time=self.shuffle_time + ) + msg = ( + f'Validation data has shape={self.val_data.shape} ' + f'and sample_shape={self.sample_shape}. Use a smaller ' + 'sample_shape and/or larger val_split.') + check = any( + val_size < samp_size + for val_size, samp_size in zip( + self.val_data.shape, self.sample_shape)) + if check: + logger.warning(msg) + warnings.warn(msg) + + @classmethod + @abstractmethod + def get_full_domain(cls, file_paths): + """Get target and shape for full domain""" + + def clear_data(self): + """Free memory used for data arrays""" + self.data = None + self.val_data = None + + @classmethod + @abstractmethod + def source_handler(cls, file_paths, **kwargs): + """Handle for source data. Uses xarray, ResourceX, etc. + + Note that xarray appears to treat open file handlers as singletons + within a threadpool, so it's okay to open this source_handler without a + context handler or a .close() statement.
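# Editor's sketch (not part of the PR): what a concrete source_handler
# typically looks like for NetCDF input, assuming xarray is installed. The
# kwargs mirror the res_kwargs documented above (e.g. chunks, combine).
import xarray as xr

def example_nc_source_handler(file_paths, **kwargs):
    # Lazily open one or more NetCDF files as a single xarray dataset.
    return xr.open_mfdataset(file_paths, **kwargs)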
+ """ + + @property + def attrs(self): + """Get atttributes of input data + + Returns + ------- + dict + Dictionary of attributes + """ + handle = self.source_handler(self.file_paths) + desc = handle.attrs + return desc + + @property + def train_only_features(self): + """Features to use for training only and not output""" + if self._train_only_features is None: + self._train_only_features = self.TRAIN_ONLY_FEATURES + return self._train_only_features + + @property + def extract_workers(self): + """Get upper bound for extract workers based on memory limits. Used to + extract data from source dataset. The max number of extract workers + is number of time chunks * number of features""" + proc_mem = 4 * self.grid_mem * len(self.time_index) + proc_mem /= len(self.time_chunks) + n_procs = len(self.time_chunks) * len(self.extract_features) + n_procs = int(np.ceil(n_procs)) + extract_workers = estimate_max_workers( + self._extract_workers, proc_mem, n_procs) + return extract_workers + + @property + def compute_workers(self): + """Get upper bound for compute workers based on memory limits. Used to + compute derived features from source dataset.""" + proc_mem = int( + np.ceil( + len(self.extract_features) + / np.maximum(len(self.derive_features), 1))) + proc_mem *= 4 * self.grid_mem * len(self.time_index) + proc_mem /= len(self.time_chunks) + n_procs = len(self.time_chunks) * len(self.derive_features) + n_procs = int(np.ceil(n_procs)) + compute_workers = estimate_max_workers( + self._compute_workers, proc_mem, n_procs) + return compute_workers + + @property + def load_workers(self): + """Get upper bound on load workers based on memory limits. Used to load + cached data.""" + proc_mem = 2 * self.feature_mem + n_procs = 1 + if self.cache_files is not None: + n_procs = len(self.cache_files) + load_workers = estimate_max_workers(self._load_workers, proc_mem, + n_procs) + return load_workers + + @property + def norm_workers(self): + """Get upper bound on workers used for normalization.""" + if self.data is not None: + norm_workers = estimate_max_workers( + self._norm_workers, 2 * self.feature_mem, self.shape[-1]) + else: + norm_workers = self._norm_workers + return norm_workers + + @property + def time_chunks(self): + """Get time chunks which will be extracted from source data + + Returns + ------- + _time_chunks : list + List of time chunks used to split up source data time dimension + so that each chunk can be extracted individually + """ + if self._time_chunks is None: + if self.is_time_independent: + self._time_chunks = [slice(None)] + else: + self._time_chunks = get_chunk_slices( + len(self.raw_time_index), + self.time_chunk_size, + self.temporal_slice, + ) + return self._time_chunks + + @property + def is_time_independent(self): + """Get whether source data files are time independent""" + return self.raw_time_index[0] is None + + @property + def n_tsteps(self): + """Get number of time steps to extract""" + if self.is_time_independent: + return 1 + else: + return len(self.raw_time_index[self.temporal_slice]) + + @property + def time_chunk_size(self): + """Get upper bound on time chunk size based on memory limits""" + if self._time_chunk_size is None: + step_mem = self.feature_mem * len(self.extract_features) + step_mem /= len(self.time_index) + if step_mem == 0: + self._time_chunk_size = self.n_tsteps + else: + self._time_chunk_size = np.min( + [int(1e9 / step_mem), self.n_tsteps]) + logger.info( + 'time_chunk_size arg not specified. 
Using ' + f'{self._time_chunk_size}.') + return self._time_chunk_size + + @property + def cache_files(self): + """Cache files for storing extracted data""" + if self._cache_files is None: + self._cache_files = self.get_cache_file_names(self.cache_pattern) + return self._cache_files + + @property + def raster_index(self): + """Raster index property""" + if self._raster_index is None: + self._raster_index = self.get_raster_index() + return self._raster_index + + @raster_index.setter + def raster_index(self, raster_index): + """Update raster index property""" + self._raster_index = raster_index + + @classmethod + def get_handle_features(cls, file_paths): + """Get all available features in input data + + Parameters + ---------- + file_paths : list + List of input file paths + + Returns + ------- + handle_features : list + List of available input features + """ + handle_features = [] + for f in file_paths: + handle = cls.source_handler([f]) + for r in handle: + handle_features.append(Feature.get_basename(r)) + return list(set(handle_features)) + + @property + def handle_features(self): + """All features available in raw input""" + if self._handle_features is None: + self._handle_features = self.get_handle_features(self.file_paths) + return self._handle_features + + @property + def noncached_features(self): + """Get list of features needing extraction or derivation""" + if self._noncached_features is None: + self._noncached_features = self.check_cached_features( + self.features, + cache_files=self.cache_files, + overwrite_cache=self.overwrite_cache, + load_cached=self.load_cached, + ) + return self._noncached_features + + @property + def extract_features(self): + """Features to extract directly from the source handler""" + lower_features = [f.lower() for f in self.handle_features] + return [f for f in self.raw_features + if self.lookup(f, 'compute') is None + or Feature.get_basename(f.lower()) in lower_features] + + @property + def derive_features(self): + """List of features which need to be derived from other features""" + derive_features = [ + f for f in set( + list(self.noncached_features) + list(self.extract_features)) + if f not in self.extract_features] + return derive_features + + @property + def cached_features(self): + """List of features which have been requested but have been determined + not to need extraction. Thus they have been cached already.""" + return [f for f in self.features if f not in self.noncached_features] + + @property + def raw_features(self): + """Get list of features needed for computations""" + if self._raw_features is None: + self._raw_features = self.get_raw_feature_list( + self.noncached_features, self.handle_features) + return self._raw_features + + @property + def output_features(self): + """Get a list of features that should be output by the generative model + corresponding to the features in the high res batch array.""" + out = [] + for feature in self.features: + ignore = any( + fnmatch(feature.lower(), pattern.lower()) + for pattern in self.train_only_features) + if not ignore: + out.append(feature) + return out + + @property + def grid_mem(self): + """Get memory used by a feature at a single time step + + Returns + ------- + int + Number of bytes for a single feature array at a single time step + """ + grid_mem = np.product(self.grid_shape) + # assuming feature arrays are float32 (4 bytes) + return 4 * grid_mem + + @property + def feature_mem(self): + """Number of bytes for a single feature array. Used to estimate + max_workers. 
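# Editor's note: a worked example of the grid_mem / feature_mem estimates
# defined here, assuming float32 (4 byte) values, a 100 x 100 raster and
# 8760 hourly time steps.
grid_shape = (100, 100)
n_time_steps = 8760
grid_mem = 4 * grid_shape[0] * grid_shape[1]  # 40,000 bytes per time step
feature_mem = grid_mem * n_time_steps         # ~350 MB for one feature array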
+ + Returns + ------- + int + Number of bytes for a single feature array + """ + feature_mem = self.grid_mem * len(self.time_index) + return feature_mem + + def preflight(self): + """Run some preflight checks and verify that the inputs are valid""" + + self.cap_worker_args(self.max_workers) + + if len(self.sample_shape) == 2: + logger.info( + 'Found 2D sample shape of {}. Adding temporal dim of 1'.format( + self.sample_shape)) + self.sample_shape = (*self.sample_shape, 1) + + start = self.temporal_slice.start + stop = self.temporal_slice.stop + n_steps = self.n_tsteps + msg = ( + f'Temporal slice step ({self.temporal_slice.step}) does not ' + f'evenly divide the number of time steps ({n_steps})') + check = self.temporal_slice.step is None + check = check or n_steps % self.temporal_slice.step == 0 + if not check: + logger.warning(msg) + warnings.warn(msg) + + msg = ( + f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger ' + 'than the number of time steps in the raw data ' + f'({len(self.raw_time_index)}).') + if len(self.raw_time_index) < self.sample_shape[2]: + logger.warning(msg) + warnings.warn(msg) + + msg = ( + f'The requested time slice {self.temporal_slice} conflicts ' + f'with the number of time steps ({len(self.raw_time_index)}) ' + 'in the raw data') + t_slice_is_subset = start is not None and stop is not None + good_subset = ( + t_slice_is_subset + and (stop - start <= len(self.raw_time_index)) + and stop <= len(self.raw_time_index) + and start <= len(self.raw_time_index)) + if t_slice_is_subset and not good_subset: + logger.error(msg) + raise RuntimeError(msg) + + msg = ( + f'Initializing DataHandler {self.input_file_info}. ' + f'Getting temporal range {self.time_index[0]!s} to ' + f'{self.time_index[-1]!s} (inclusive) ' + f'based on temporal_slice {self.temporal_slice}') + logger.info(msg) + + logger.info( + f'Using max_workers={self.max_workers}, ' + f'norm_workers={self.norm_workers}, ' + f'extract_workers={self.extract_workers}, ' + f'compute_workers={self.compute_workers}, ' + f'load_workers={self.load_workers}, ' + f'ti_workers={self.ti_workers}') + + @classmethod + def get_lat_lon(cls, file_paths, raster_index, invert_lat=False): + """Get lat/lon grid for requested target and shape + + Parameters + ---------- + file_paths : list + path to data file + raster_index : ndarray | list + Raster index array or list of slices + invert_lat : bool + Flag to invert data along the latitude axis. WRF data tends to use + an increasing ordering for latitude while WTK uses a decreasing + ordering. + + Returns + ------- + ndarray + (spatial_1, spatial_2, 2) Lat/Lon array with same ordering in last + dimension + """ + lat_lon = cls.lookup('lat_lon', 'compute')(file_paths, raster_index) + if invert_lat: + lat_lon = lat_lon[::-1] + # put angle between -180 and 180 + lat_lon[..., 1] = (lat_lon[..., 1] + 180) % 360 - 180 + return lat_lon + + @classmethod + def get_node_cmd(cls, config): + """Get a CLI call to initialize DataHandler and cache data. + + Parameters + ---------- + config : dict + sup3r data handler config with all necessary args and kwargs to + initialize DataHandler and run data extraction.
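# Editor's sketch: an illustrative config for get_node_cmd. The log_file,
# log_level and cache_pattern keys are the ones read in the method body
# below; the remaining keys are DataHandler init arguments and the file
# paths are placeholders.
example_config = {
    'file_paths': ['/scratch/era5_2015.nc'],
    'features': ['U_100m', 'V_100m'],
    'target': (39.0, -105.0),
    'shape': (20, 20),
    'cache_pattern': '/scratch/cache_{feature}.pkl',
    'log_level': 'DEBUG',
}
# cmd = DataHandlerNC.get_node_cmd(example_config)  # shell command string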
+ """ + + import_str = ( + 'from sup3r.preprocessing.data_handling ' + f'import {cls.__name__};\n' + 'import time;\n' + 'from sup3r.pipeline import Status;\n' + 'from rex import init_logger;\n') + + dh_init_str = get_fun_call_str(cls, config) + + log_file = config.get('log_file', None) + log_level = config.get('log_level', 'INFO') + log_arg_str = f'"sup3r", log_level="{log_level}"' + if log_file is not None: + log_arg_str += f', log_file="{log_file}"' + + cache_check = config.get('cache_pattern', False) + + msg = 'No cache file prefix provided.' + if not cache_check: + logger.warning(msg) + warnings.warn(msg) + + cmd = ( + f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"data_handler = {dh_init_str};\n" + "t_elap = time.time() - t0;\n") + + cmd = BaseCLI.add_status_cmd(config, ModuleName.DATA_EXTRACT, cmd) + + cmd += ";\'\n" + return cmd.replace('\\', '/') + + def get_cache_file_names(self, + cache_pattern, + grid_shape=None, + time_index=None, + target=None, + features=None): + """Get names of cache files from cache_pattern and feature names + + Parameters + ---------- + cache_pattern : str + Pattern to use for cache file names + grid_shape : tuple + Shape of grid to use for cache file naming + time_index : list | pd.DatetimeIndex + Time index to use for cache file naming + target : tuple + Target to use for cache file naming + features : list + List of features to use for cache file naming + + Returns + ------- + list + List of cache file names + """ + grid_shape = grid_shape if grid_shape is not None else self.grid_shape + time_index = time_index if time_index is not None else self.time_index + target = target if target is not None else self.target + features = features if features is not None else self.features + + return self._get_cache_file_names( + cache_pattern, grid_shape, time_index, target, features) + + def unnormalize(self, means, stds): + """Remove normalization from stored means and stds""" + self._unnormalize(self.data, self.val_data, means, stds) + + def normalize(self, means, stds): + """Normalize all data features + + Parameters + ---------- + means : np.ndarray + dimensions (features) + array of means for all features with same ordering as data features + stds : np.ndarray + dimensions (features) + array of means for all features with same ordering as data features + """ + max_workers = self.norm_workers + self._normalize( + self.data, self.val_data, means, stds, max_workers=max_workers) + + def get_next(self): + """Get data for observation using random observation index. Loops + repeatedly over randomized time index + + Returns + ------- + observation : np.ndarray + 4D array + (spatial_1, spatial_2, temporal, features) + """ + self.current_obs_index = self._get_observation_index( + self.data, self.sample_shape) + observation = self.data[self.current_obs_index] + return observation + + def split_data(self, data=None, val_split=0.0, shuffle_time=False): + """Split time dimension into set of training indices and validation + indices + + Parameters + ---------- + data : np.ndarray + 4D array of high res data + (spatial_1, spatial_2, temporal, features) + val_split : float + Fraction of data to separate for validation. + shuffle_time : bool + Whether to shuffle time or not. + + Returns + ------- + data : np.ndarray + (spatial_1, spatial_2, temporal, features) + Training data fraction of initial data array. Initial data array is + overwritten by this new data array. 
+ val_data : np.ndarray + (spatial_1, spatial_2, temporal, features) + Validation data fraction of initial data array. + """ + data = data if data is not None else self.data + + assert len(self.time_index) == self.data.shape[-2] + + train_indices, val_indices = self._split_data_indices( + data, val_split=val_split, shuffle_time=shuffle_time) + self.val_data = self.data[:, :, val_indices, :] + self.data = self.data[:, :, train_indices, :] + + self.val_time_index = self.time_index[val_indices] + self.time_index = self.time_index[train_indices] + + return self.data, self.val_data + + @property + def shape(self): + """Full data shape + + Returns + ------- + shape : tuple + Full data shape + (spatial_1, spatial_2, temporal, features) + """ + return self.data.shape + + def cache_data(self, cache_file_paths): + """Cache feature data to file and delete from memory + + Parameters + ---------- + cache_file_paths : str | None + Path to file for saving feature data + """ + self._cache_data( + self.data, self.features, cache_file_paths, self.overwrite_cache + ) + + @property + def requested_shape(self): + """Get requested shape for cached data""" + shape = get_raster_shape(self.raster_index) + requested_shape = ( + shape[0] // self.hr_spatial_coarsen, + shape[1] // self.hr_spatial_coarsen, + len(self.raw_time_index[self.temporal_slice]), + len(self.features)) + return requested_shape + + def load_cached_data(self, with_split=True): + """Load data from cache files and split into training and validation + + Parameters + ---------- + with_split : bool + Whether to split into training and validation data or not. + """ + if self.data is not None: + logger.info('Called load_cached_data() but self.data is not None') + + elif self.data is None: + msg = ( + 'Found {} cache files but need {} for features {}! ' + 'These are the cache files that were found: {}'.format( + len(self.cache_files), + len(self.features), + self.features, + self.cache_files)) + assert len(self.cache_files) == len(self.features), msg + + self.data = np.full(shape=self.requested_shape, fill_value=np.nan, + dtype=np.float32) + + logger.info(f'Loading cached data from: {self.cache_files}') + max_workers = self.load_workers + self._load_cached_data(data=self.data, + cache_files=self.cache_files, + features=self.features, + max_workers=max_workers) + + self.time_index = self.raw_time_index[self.temporal_slice] + + nan_perc = 100 * np.isnan(self.data).sum() / self.data.size + if nan_perc > 0: + msg = 'Data has {:.2f}% NaN values!'.format(nan_perc) + logger.warning(msg) + warnings.warn(msg) + + logger.debug( + 'Splitting data into training / validation sets ' + f'({1 - self.val_split}, {self.val_split}) ' + f'for {self.input_file_info}') + + if with_split: + self.data, self.val_data = self.split_data( + val_split=self.val_split, shuffle_time=self.shuffle_time) + + def run_all_data_init(self): + """Build base 4D data array. 
Can handle multiple files but assumes + each file has the same spatial domain + + Returns + ------- + data : np.ndarray + 4D array of high res data + (spatial_1, spatial_2, temporal, features) + """ + now = dt.now() + logger.debug(f'Loading data for raster of shape {self.grid_shape}') + + # get the file-native time index without pruning + if self.is_time_independent: + n_steps = 1 + shifted_time_chunks = [slice(None)] + else: + n_steps = len(self.raw_time_index[self.temporal_slice]) + shifted_time_chunks = get_chunk_slices( + n_steps, self.time_chunk_size) + + self.run_data_extraction() + self.run_data_compute() + + logger.info('Building final data array') + self.parallel_data_fill(shifted_time_chunks, self.extract_workers) + + if self.invert_lat: + self.data = self.data[::-1] + + if self.time_roll != 0: + logger.debug('Applying time roll to data array') + self.data = np.roll(self.data, self.time_roll, axis=2) + + if self.hr_spatial_coarsen > 1: + logger.debug('Applying hr spatial coarsening to data array') + self.data = spatial_coarsening( + self.data, s_enhance=self.hr_spatial_coarsen, obs_axis=False) + self.lat_lon = spatial_coarsening( + self.lat_lon, s_enhance=self.hr_spatial_coarsen, + obs_axis=False) + if self.load_cached: + for f in self.cached_features: + f_index = self.features.index(f) + logger.info(f'Loading {f} from {self.cache_files[f_index]}') + with open(self.cache_files[f_index], 'rb') as fh: + self.data[..., f_index] = pickle.load(fh) + + logger.info(f'Finished extracting data for {self.input_file_info} in ' + f'{dt.now() - now}') + return self.data + + def run_data_extraction(self): + """Run the raw dataset extraction process from disk to raw + un-manipulated datasets. + """ + if self.extract_features: + logger.info( + f'Starting extraction of {self.extract_features} ' + f'using {len(self.time_chunks)} time_chunks.') + if self.extract_workers == 1: + self._raw_data = self.serial_extract(self.file_paths, + self.raster_index, + self.time_chunks, + self.extract_features, + **self.res_kwargs) + + else: + self._raw_data = self.parallel_extract(self.file_paths, + self.raster_index, + self.time_chunks, + self.extract_features, + self.extract_workers, + **self.res_kwargs) + + logger.info( + f'Finished extracting {self.extract_features} for ' + f'{self.input_file_info}' + ) + + def run_data_compute(self): + """Run the data computation / derivation from raw features to desired + features. 
+ """ + if self.derive_features: + logger.info(f'Starting computation of {self.derive_features}') + + if self.compute_workers == 1: + self._raw_data = self.serial_compute(self._raw_data, + self.file_paths, + self.raster_index, + self.time_chunks, + self.derive_features, + self.noncached_features, + self.handle_features) + + elif self.compute_workers != 1: + self._raw_data = self.parallel_compute(self._raw_data, + self.file_paths, + self.raster_index, + self.time_chunks, + self.derive_features, + self.noncached_features, + self.handle_features, + self.compute_workers) + + logger.info( + f'Finished computing {self.derive_features} for ' + f'{self.input_file_info}') + + def data_fill(self, t, t_slice, f_index, f): + """Place single extracted / computed chunk in final data array + + Parameters + ---------- + t : int + Index of time slice in extracted / computed raw data dictionary + t_slice : slice + Time slice corresponding to the location in the final data array + f_index : int + Index of feature in the final data array + f : str + Name of corresponding feature in the raw data dictionary + """ + tmp = self._raw_data[t][f] + if len(tmp.shape) == 2: + tmp = tmp[..., np.newaxis] + self.data[..., t_slice, f_index] = tmp + + def serial_data_fill(self, shifted_time_chunks): + """Fill final data array in serial + + Parameters + ---------- + shifted_time_chunks : list + List of time slices corresponding to the appropriate location of + extracted / computed chunks in the final data array + """ + for t, ts in enumerate(shifted_time_chunks): + for _, f in enumerate(self.noncached_features): + f_index = self.features.index(f) + self.data_fill(t, ts, f_index, f) + interval = int(np.ceil(len(shifted_time_chunks) / 10)) + if t % interval == 0: + logger.info( + f'Added {t + 1} of {len(shifted_time_chunks)} ' + 'chunks to final data array') + self._raw_data.pop(t) + + def parallel_data_fill(self, shifted_time_chunks, max_workers=None): + """Fill final data array with extracted / computed chunks + + Parameters + ---------- + shifted_time_chunks : list + List of time slices corresponding to the appropriate location of + extracted / computed chunks in the final data array + max_workers : int | None + Max number of workers to use for building final data array. If None + max available workers will be used. 
If 1 cached data will be loaded + in serial + """ + self.data = np.zeros((self.grid_shape[0], + self.grid_shape[1], + self.n_tsteps, + len(self.features)), + dtype=np.float32) + + if max_workers == 1: + self.serial_data_fill(shifted_time_chunks) + else: + with ThreadPoolExecutor(max_workers=max_workers) as exe: + futures = {} + now = dt.now() + for t, ts in enumerate(shifted_time_chunks): + for _, f in enumerate(self.noncached_features): + f_index = self.features.index(f) + future = exe.submit(self.data_fill, t, ts, f_index, f) + futures[future] = {'t': t, 'fidx': f_index} + + logger.info( + f'Started adding {len(futures)} chunks ' + f'to data array in {dt.now() - now}.') + + interval = int(np.ceil(len(futures) / 10)) + for i, future in enumerate(as_completed(futures)): + try: + future.result() + except Exception as e: + msg = ( + f'Error adding ({futures[future]["t"]}, ' + f'{futures[future]["fidx"]}) chunk to ' + 'final data array.') + logger.exception(msg) + raise RuntimeError(msg) from e + if i % interval == 0: + logger.debug( + f'Added {i+1} out of {len(futures)} ' + 'chunks to final data array') + logger.info('Finished building data array') + + @abstractmethod + def get_raster_index(self): + """Get raster index for file data. Here we assume the list of paths in + file_paths all have data with the same spatial domain. We use the first + file in the list to compute the raster + + Returns + ------- + raster_index : np.ndarray + 2D array of grid indices for H5 or list of + slices for NETCDF + """ + + def lin_bc(self, bc_files, threshold=0.1): + """Bias correct the data in this DataHandler using linear bias + correction factors from files output by MonthlyLinearCorrection or + LinearCorrection from sup3r.bias.bias_calc + + Parameters + ---------- + bc_files : list | tuple | str + One or more filepaths to .h5 files output by + MonthlyLinearCorrection or LinearCorrection. These should contain + datasets named "{feature}_scalar" and "{feature}_adder" where + {feature} is one of the features contained by this DataHandler and + the data is a 3D array of shape (lat, lon, time) where time is + length 1 for annual correction or 12 for monthly correction. + threshold : float + Nearest neighbor euclidean distance threshold. If the DataHandler + coordinates are more than this value away from the bias correction + lat/lon, an error is raised. 
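# Editor's note: the correction applied below reduces to
# data = data * scalar + adder, broadcast over time. A minimal sketch with
# made-up monthly (length-12) factors:
import numpy as np

data = np.ones((2, 2, 3), dtype=np.float32)  # (lat, lon, time)
months = np.array([1, 1, 2])                 # month of each time step
scalar = np.full((2, 2, 12), 1.05)           # hypothetical monthly scalars
adder = np.full((2, 2, 12), -0.1)            # hypothetical monthly adders
corrected = data * scalar[..., months - 1] + adder[..., months - 1]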
+ """ + + if isinstance(bc_files, str): + bc_files = [bc_files] + + completed = [] + for idf, feature in enumerate(self.features): + for fp in bc_files: + if feature not in completed: + scalar, adder = get_spatial_bc_factors( + lat_lon=self.lat_lon, + feature_name=feature, + bias_fp=fp, + threshold=threshold) + + if scalar.shape[-1] == 1: + scalar = np.repeat(scalar, self.shape[2], axis=2) + adder = np.repeat(adder, self.shape[2], axis=2) + elif scalar.shape[-1] == 12: + idm = self.time_index.month.values - 1 + scalar = scalar[..., idm] + adder = adder[..., idm] + else: + msg = ( + 'Can only accept bias correction factors ' + 'with last dim equal to 1 or 12 but ' + 'received bias correction factors with ' + 'shape {}'.format(scalar.shape)) + logger.error(msg) + raise RuntimeError(msg) + + logger.info( + 'Bias correcting "{}" with linear ' + 'correction from "{}"'.format( + feature, os.path.basename(fp))) + self.data[..., idf] *= scalar + self.data[..., idf] += adder + completed.append(feature) + + +# pylint: disable=W0223 +class DataHandlerDC(DataHandler): + """Data-centric data handler""" + + def get_observation_index(self, temporal_weights=None, + spatial_weights=None): + """Randomly gets weighted spatial sample and time sample + + Parameters + ---------- + temporal_weights : array + Weights used to select time slice + (n_time_chunks) + spatial_weights : array + Weights used to select spatial chunks + (n_lat_chunks * n_lon_chunks) + + Returns + ------- + observation_index : tuple + Tuple of sampled spatial grid, time slice, and features indices. + Used to get single observation like self.data[observation_index] + """ + if spatial_weights is not None: + spatial_slice = weighted_box_sampler( + self.data, self.sample_shape[:2], weights=spatial_weights) + else: + spatial_slice = uniform_box_sampler( + self.data, self.sample_shape[:2]) + if temporal_weights is not None: + temporal_slice = weighted_time_sampler( + self.data, self.sample_shape[2], weights=temporal_weights) + else: + temporal_slice = uniform_time_sampler( + self.data, self.sample_shape[2]) + + return tuple( + [*spatial_slice, temporal_slice, np.arange(len(self.features))]) + + def get_next(self, temporal_weights=None, spatial_weights=None): + """Get data for observation using weighted random observation index. + Loops repeatedly over randomized time index. + + Parameters + ---------- + temporal_weights : array + Weights used to select time slice + (n_time_chunks) + spatial_weights : array + Weights used to select spatial chunks + (n_lat_chunks * n_lon_chunks) + + Returns + ------- + observation : np.ndarray + 4D array + (spatial_1, spatial_2, temporal, features) + """ + self.current_obs_index = self.get_observation_index( + temporal_weights=temporal_weights, spatial_weights=spatial_weights) + observation = self.data[self.current_obs_index] + return observation + + @classmethod + def get_file_times(cls, file_paths, **kwargs): + """Get time index from data files + + Parameters + ---------- + file_paths : list + path to data file + kwargs : dict + kwargs passed to source handler for data extraction. e.g. 
This + could be {'parallel': True, + 'chunks': {'south_north': 120, 'west_east': 120}} + which then gets passed to xr.open_mfdataset(file, **kwargs) + + Returns + ------- + time_index : pd.Datetimeindex + List of times as a Datetimeindex + """ + handle = cls.source_handler(file_paths, **kwargs) + + if hasattr(handle, 'Times'): + time_index = np_to_pd_times(handle.Times.values) + elif hasattr(handle, 'indexes') and 'time' in handle.indexes: + time_index = handle.indexes['time'] + if not isinstance(time_index, pd.DatetimeIndex): + time_index = time_index.to_datetimeindex() + elif hasattr(handle, 'times'): + time_index = np_to_pd_times(handle.times.values) + else: + msg = ( + f'Could not get time_index for {file_paths}. ' + 'Assuming time independence.') + time_index = None + logger.warning(msg) + warnings.warn(msg) + + return time_index + + @classmethod + def get_time_index(cls, file_paths, max_workers=None, **kwargs): + """Get time index from data files + + Parameters + ---------- + file_paths : list + path to data file + max_workers : int | None + Max number of workers to use for parallel time index building + kwargs : dict + kwargs passed to source handler for data extraction. e.g. This + could be {'parallel': True, + 'chunks': {'south_north': 120, 'west_east': 120}} + which then gets passed to xr.open_mfdataset(file, **kwargs) + + Returns + ------- + time_index : pd.Datetimeindex + List of times as a Datetimeindex + """ + max_workers = (len(file_paths) + if max_workers is None + else np.min((max_workers, len(file_paths)))) + if max_workers == 1: + return cls.get_file_times(file_paths, **kwargs) + ti = {} + with ThreadPoolExecutor(max_workers=max_workers) as exe: + futures = {} + now = dt.now() + for i, f in enumerate(file_paths): + future = exe.submit(cls.get_file_times, [f], **kwargs) + futures[future] = {'idx': i, 'file': f} + + logger.info( + f'Started building time index from {len(file_paths)} ' + f'files in {dt.now() - now}.') + + for i, future in enumerate(as_completed(futures)): + try: + val = future.result() + if val is not None: + ti[futures[future]['idx']] = list(val) + except Exception as e: + msg = ('Error while getting time index from file ' + f'{futures[future]["file"]}.') + logger.exception(msg) + raise RuntimeError(msg) from e + logger.debug(f'Stored {i+1} out of {len(futures)} file times') + times = np.concatenate(list(ti.values())) + return pd.DatetimeIndex(sorted(set(times))) + + @classmethod + def feature_registry(cls): + """Registry of methods for computing features + + Returns + ------- + dict + Method registry + """ + registry = {'BVF2_(.*)m': BVFreqSquaredNC, + 'BVF_MO_(.*)m': BVFreqMon, + 'RMOL': InverseMonNC, + 'U_(.*)': UWind, + 'V_(.*)': VWind, + 'Windspeed_(.*)m': WindspeedNC, + 'Winddirection_(.*)m': WinddirectionNC, + 'lat_lon': LatLonNC, + 'Shear_(.*)m': Shear, + 'REWS_(.*)m': Rews, + 'Temperature_(.*)m': TempNC, + 'Pressure_(.*)m': PressureNC, + 'PotentialTemp_(.*)m': PotentialTempNC, + 'PT_(.*)m': PotentialTempNC, + 'topography': ['HGT', 'orog']} + return registry + + @classmethod + def extract_feature(cls, + file_paths, + raster_index, + feature, + time_slice=slice(None), + **kwargs): + """Extract single feature from data source. The requested feature + can match exactly to one found in the source data or can have a + matching prefix with a suffix specifying the height or pressure level + to interpolate to. e.g. feature=U_100m -> interpolate exact match U to + 100 meters. 
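# Editor's sketch: how a requested feature name with a height suffix relates
# to the registry patterns above. 'U_100m' matches the 'U_(.*)' entry, and
# the captured suffix gives the height to interpolate to.
import re

requested = 'U_100m'
match = re.match('U_(.*)', requested)
height = float(match.group(1).strip('m'))  # 100.0 -> interpolate U to 100 m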
+ + Parameters + ---------- + file_paths : list + path to data file + raster_index : ndarray + Raster index array + feature : str + Feature to extract from data + time_slice : slice + slice of time to extract + kwargs : dict + kwargs passed to source handler for data extraction. e.g. This + could be {'parallel': True, + 'chunks': {'south_north': 120, 'west_east': 120}} + which then gets passed to xr.open_mfdataset(file, **kwargs) + + Returns + ------- + ndarray + Data array for extracted feature + (spatial_1, spatial_2, temporal) + """ + logger.debug( + f'Extracting {feature} with time_slice={time_slice}, ' + f'raster_index={raster_index}, kwargs={kwargs}.' + ) + handle = cls.source_handler(file_paths, **kwargs) + f_info = Feature(feature, handle) + interp_height = f_info.height + interp_pressure = f_info.pressure + basename = f_info.basename + + if feature in handle or feature.lower() in handle: + feat_key = feature if feature in handle else feature.lower() + fdata = cls.direct_extract( + handle, feat_key, raster_index, time_slice) + + elif basename in handle or basename.lower() in handle: + feat_key = basename if basename in handle else basename.lower() + if interp_height is not None: + fdata = Interpolator.interp_var_to_height( + handle, + feat_key, + raster_index, + np.float32(interp_height), + time_slice) + elif interp_pressure is not None: + fdata = Interpolator.interp_var_to_pressure( + handle, + feat_key, + raster_index, + np.float32(interp_pressure), + time_slice) + + else: + msg = f'{feature} cannot be extracted from source data.' + logger.exception(msg) + raise ValueError(msg) + + fdata = np.transpose(fdata, (1, 2, 0)) + return fdata.astype(np.float32) + + @classmethod + def direct_extract(cls, handle, feature, raster_index, time_slice): + """Extract requested feature directly from source data, rather than + interpolating to a requested height or pressure level + + Parameters + ---------- + handle : xarray + netcdf data object + feature : str + Name of feature to extract directly from source handler + raster_index : list + List of slices for raster index of spatial domain + time_slice : slice + slice of time to extract + + Returns + ------- + fdata : ndarray + Data array for requested feature + """ + # Sometimes xarray returns fields with (Times, time, lats, lons) + # with a single entry in the 'time' dimension so we include this [0] + if len(handle[feature].dims) == 4: + idx = tuple([time_slice, 0, *raster_index]) + elif len(handle[feature].dims) == 3: + idx = tuple([time_slice, *raster_index]) + else: + idx = tuple(raster_index) + fdata = np.array(handle[feature][idx], dtype=np.float32) + if len(fdata.shape) == 2: + fdata = np.expand_dims(fdata, axis=0) + return fdata + + @classmethod + def get_full_domain(cls, file_paths): + """Get full shape and min available lat lon. To simplify processing + of full domain without needing to specify target and shape. 
+ + Parameters + ---------- + file_paths : list + List of data file paths + + Returns + ------- + target : tuple + (lat, lon) for lower left corner + lat_lon : ndarray + Raw lat/lon array for entire domain + """ + return cls.get_lat_lon(file_paths, [slice(None), slice(None)]) + + @staticmethod + def get_closest_lat_lon(lat_lon, target): + """Get closest indices to target lat lon to use for lower left corner + of raster index + + Parameters + ---------- + lat_lon : ndarray + Array of lat/lon + (spatial_1, spatial_2, 2) + Last dimension in order of (lat, lon) + target : tuple + (lat, lon) for lower left corner + + Returns + ------- + row : int + row index for closest lat/lon to target lat/lon + col : int + col index for closest lat/lon to target lat/lon + """ + # shape of ll2 is (n, 2) where axis=1 is (lat, lon) + ll2 = np.vstack( + (lat_lon[..., 0].flatten(), lat_lon[..., 1].flatten())).T + tree = KDTree(ll2) + _, i = tree.query(np.array(target)) + row, col = np.where( + (lat_lon[..., 0] == ll2[i, 0]) & (lat_lon[..., 1] == ll2[i, 1])) + row = row[0] + col = col[0] + return row, col + + @classmethod + def compute_raster_index(cls, file_paths, target, grid_shape): + """Get raster index for a given target and shape + + Parameters + ---------- + file_paths : list + List of input data file paths + target : tuple + Target coordinate for lower left corner of extracted data + grid_shape : tuple + Shape out extracted data + + Returns + ------- + list + List of slices corresponding to extracted data region + """ + lat_lon = cls.get_lat_lon( + file_paths[:1], [slice(None), slice(None)], invert_lat=False + ) + cls._check_grid_extent(target, grid_shape, lat_lon) + + row, col = cls.get_closest_lat_lon(lat_lon, target) + + closest = tuple(lat_lon[row, col]) + logger.debug(f'Found closest coordinate {closest} to target={target}') + if np.hypot(closest[0] - target[0], closest[1] - target[1]) > 1: + msg = 'Closest coordinate to target is more than 1 degree away' + logger.warning(msg) + warnings.warn(msg) + + if cls.lats_are_descending(lat_lon): + row_end = row + 1 + row_start = row_end - grid_shape[0] + else: + row_end = row + grid_shape[0] + row_start = row + raster_index = [slice(row_start, row_end), + slice(col, col + grid_shape[1])] + cls._validate_raster_shape(target, grid_shape, lat_lon, raster_index) + return raster_index + + @classmethod + def _check_grid_extent(cls, target, grid_shape, lat_lon): + """Make sure the requested target coordinate lies within the available + lat/lon grid. + + Parameters + ---------- + target : tuple + Target coordinate for lower left corner of extracted data + grid_shape : tuple + Shape out extracted data + lat_lon : ndarray + Array of lat/lon coordinates for entire available grid. Used to + check whether computed raster only includes coordinates within this + grid. 
+ """ + min_lat = np.min(lat_lon[..., 0]) + min_lon = np.min(lat_lon[..., 1]) + max_lat = np.max(lat_lon[..., 0]) + max_lon = np.max(lat_lon[..., 1]) + logger.debug( + 'Calculating raster index from WRF file ' + f'for shape {grid_shape} and target {target}') + logger.debug( + f'lat/lon (min, max): {min_lat}/{min_lon}, ' + f'{max_lat}/{max_lon}') + msg = ( + f'target {target} out of bounds with min lat/lon ' + f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}') + assert (min_lat <= target[0] <= max_lat + and min_lon <= target[1] <= max_lon), msg + + @classmethod + def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): + """Make sure the computed raster_index only includes coordinates within + the available grid + + Parameters + ---------- + target : tuple + Target coordinate for lower left corner of extracted data + grid_shape : tuple + Shape out extracted data + lat_lon : ndarray + Array of lat/lon coordinates for entire available grid. Used to + check whether computed raster only includes coordinates within this + grid. + raster_index : list + List of slices selecting region from entire available grid. + """ + check = (raster_index[0].stop > lat_lon.shape[0] + or raster_index[1].stop > lat_lon.shape[1] + or raster_index[0].start < 0 + or raster_index[1].start < 0) + if check: + msg = ( + f'Invalid target {target}, shape {grid_shape}, and raster ' + f'{raster_index} for data domain of size ' + f'{lat_lon.shape[:-1]} with lower left corner ' + f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' + f' and upper right corner ({np.max(lat_lon[..., 0])}, ' + f'{np.max(lat_lon[..., 1])}).') + raise ValueError(msg) + + def get_raster_index(self): + """Get raster index for file data. Here we assume the list of paths in + file_paths all have data with the same spatial domain. We use the first + file in the list to compute the raster. 
+ + Returns + ------- + raster_index : np.ndarray + 2D array of grid indices + """ + self.raster_file = ( + self.raster_file + if self.raster_file is None + else self.raster_file.replace('.txt', '.npy')) + if self.raster_file is not None and os.path.exists(self.raster_file): + logger.debug( + f'Loading raster index: {self.raster_file} ' + f'for {self.input_file_info}') + raster_index = np.load(self.raster_file, allow_pickle=True) + raster_index = list(raster_index) + else: + check = self.grid_shape is not None and self.target is not None + msg = ( + 'Must provide raster file or shape + target to get ' + 'raster index') + assert check, msg + raster_index = self.compute_raster_index( + self.file_paths, self.target, self.grid_shape) + logger.debug( + 'Found raster index with row, col slices: {}'.format( + raster_index)) + + if self.raster_file is not None: + basedir = os.path.dirname(self.raster_file) + if not os.path.exists(basedir): + os.makedirs(basedir) + logger.debug(f'Saving raster index: {self.raster_file}') + np.save(self.raster_file.replace('.txt', '.npy'), raster_index) + + return raster_index diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py new file mode 100644 index 000000000..70bbd4af8 --- /dev/null +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -0,0 +1,475 @@ +"""Dual data handler class for using separate low_res and high_res datasets""" +import logging +from warnings import warn + +import numpy as np +import pandas as pd + +from sup3r.preprocessing.data_handling.mixin import ( + CacheHandlingMixIn, + TrainingPrepMixIn, +) +from sup3r.utilities.regridder import Regridder +from sup3r.utilities.utilities import spatial_coarsening + +logger = logging.getLogger(__name__) + + +# pylint: disable=unsubscriptable-object +class DualDataHandler(CacheHandlingMixIn, TrainingPrepMixIn): + """Batch handling class for h5 data as high res (usually WTK) and netcdf + data as low res (usually ERA5)""" + + def __init__(self, + hr_handler, + lr_handler, + regrid_cache_pattern=None, + overwrite_regrid_cache=False, + regrid_workers=1, + load_cached=True, + shuffle_time=False, + s_enhance=15, + t_enhance=1, + val_split=0.0): + """Initialize data handler using hr and lr data handlers for h5 data + and nc data + + Parameters + ---------- + hr_handler : DataHandler + DataHandler for high_res data + lr_handler : DataHandler + DataHandler for low_res data + regrid_cache_pattern : str + Pattern for files to use for saving regridded ERA data. + overwrite_regrid_cache : bool + Whether to overwrite regrid cache + regrid_workers : int | None + Number of workers to use for regridding routine. + load_cached : bool + Whether to load cache to memory or wait until load_cached() + is called. + shuffle_time : bool + Whether to shuffle time indices prior to training/validation split + s_enhance : int + Spatial enhancement factor + t_enhance : int + Temporal enhancement factor + val_split : float + Percentage of data to reserve for validation. 
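# Editor's sketch: pairing an H5 high-res handler with a NetCDF low-res
# handler. File paths, targets and features are placeholders; the sample
# shapes and grid shapes are chosen to satisfy the s_enhance/t_enhance
# consistency checks performed by this class.
from sup3r.preprocessing.data_handling import DataHandlerH5, DataHandlerNC

hr = DataHandlerH5('/data/wtk_2015.h5', ['U_100m', 'V_100m'],
                   target=(39.0, -105.0), shape=(30, 30),
                   sample_shape=(30, 30, 1), val_split=0.0)
lr = DataHandlerNC('/data/era5_2015.nc', ['U_100m', 'V_100m'],
                   target=(39.0, -105.0), shape=(2, 2),
                   sample_shape=(2, 2, 1), val_split=0.0)
dual = DualDataHandler(hr, lr, s_enhance=15, t_enhance=1, val_split=0.1)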
+ """ + self.s_enhance = s_enhance + self.t_enhance = t_enhance + self.lr_dh = lr_handler + self.hr_dh = hr_handler + self.regrid_cache_pattern = regrid_cache_pattern + self.overwrite_regrid_cache = overwrite_regrid_cache + self.val_split = val_split + self.current_obs_index = None + self.load_cached = load_cached + self.regrid_workers = regrid_workers + self.shuffle_time = shuffle_time + self._lr_lat_lon = None + self._hr_lat_lon = None + self.lr_data = None + self.hr_data = None + self.lr_val_data = None + self.hr_val_data = None + self.lr_time_index = None + self.hr_time_index = None + self.lr_val_time_index = None + self.hr_val_time_index = None + + if self.try_load and self.load_cached: + self.load_cached_data() + + if not self.try_load: + self.get_data() + + self._run_pair_checks(hr_handler, lr_handler) + + self.check_clear_data() + + logger.info('Finished initializing DualDataHandler.') + + def get_data(self): + """Check hr and lr shapes and trim hr data if needed to match required + relationship to lr shape based on enhancement factors. Then regrid lr + data and split hr and lr data into training and validation sets.""" + self._shape_check() + self.get_lr_data() + self._val_split_check() + + def _val_split_check(self): + """Check if val_split > 0 and split data into validation and training. + Make sure validation data is larger than sample_shape""" + + if self.hr_data is not None and self.val_split > 0.0: + n_val_obs = self.hr_data.shape[2] * (1 - self.val_split) + n_val_obs = int(self.t_enhance * (n_val_obs // self.t_enhance)) + train_indices, val_indices = self._split_data_indices( + self.hr_data, + n_val_obs=n_val_obs, + shuffle_time=self.shuffle_time) + self.hr_val_data = self.hr_data[:, :, val_indices, :] + self.hr_data = self.hr_data[:, :, train_indices, :] + self.hr_val_time_index = self.hr_time_index[val_indices] + self.hr_time_index = self.hr_time_index[train_indices] + msg = ('High res validation data has shape=' + f'{self.hr_val_data.shape} and sample_shape=' + f'{self.hr_sample_shape}. Use a smaller sample_shape ' + 'and/or larger val_split.') + check = any(val_size < samp_size for val_size, samp_size in zip( + self.hr_val_data.shape, self.hr_sample_shape)) + if check: + logger.warning(msg) + warn(msg) + + if self.lr_data is not None and self.val_split > 0.0: + train_indices = list(set(train_indices // self.t_enhance)) + val_indices = list(set(val_indices // self.t_enhance)) + + self.lr_val_data = self.lr_data[:, :, val_indices, :] + self.lr_data = self.lr_data[:, :, train_indices, :] + + self.lr_val_time_index = self.lr_time_index[val_indices] + self.lr_time_index = self.lr_time_index[train_indices] + + msg = ('Low res validation data has shape=' + f'{self.lr_val_data.shape} and sample_shape=' + f'{self.lr_sample_shape}. 
Use a smaller sample_shape ' + 'and/or larger val_split.') + check = any(val_size < samp_size + for val_size, samp_size in zip( + self.lr_val_data.shape, self.lr_sample_shape)) + if check: + logger.warning(msg) + warn(msg) + + def normalize(self, means, stdevs): + """Normalize low_res data + + Parameters + ---------- + means : np.ndarray + dimensions (features) + array of means for all features with same ordering as data features + stdevs : np.ndarray + dimensions (features) + array of means for all features with same ordering as data features + """ + self._normalize(data=self.lr_data, + val_data=self.lr_val_data, + means=means, + stds=stdevs, + max_workers=self.lr_dh.norm_workers) + + @property + def output_features(self): + """Get list of output features. e.g. those that are returned by a + GAN""" + return self.hr_dh.output_features + + def _shape_check(self): + """Check if hr_handler.shape is divisible by s_enhance. If not take + the largest shape that can be.""" + + if self.hr_data is None: + logger.info("Loading high resolution cache.") + self.hr_dh.load_cached_data(with_split=False) + + msg = ('hr_handler.shape is not divisible by s_enhance. Using ' + f'shape = {self.hr_required_shape} instead.') + if self.hr_dh.shape[:-1] != self.hr_required_shape: + logger.warning(msg) + warn(msg) + + self.hr_data = self.hr_dh.data[:self.hr_required_shape[0], + :self.hr_required_shape[1], + :self.hr_required_shape[2]] + self.hr_time_index = self.hr_dh.time_index[:self.hr_required_shape[2]] + self.lr_time_index = self.lr_dh.time_index[:self.lr_required_shape[2]] + + assert np.array_equal(self.hr_time_index[::self.t_enhance].values, + self.lr_time_index) + + def _run_pair_checks(self, hr_handler, lr_handler): + """Run sanity checks on high_res and low_res pairs. The handler data + shapes are restricted by enhancement factors.""" + msg = ('Validation split is done by DualDataHandler. ' + 'hr_handler.val_split and lr_handler.val_split should both be ' + 'zero.') + assert hr_handler.val_split == 0 and lr_handler.val_split == 0, msg + msg = ('Handlers have incompatible number of features. ' + f'({hr_handler.features} vs {lr_handler.features})') + assert hr_handler.features == lr_handler.features, msg + hr_shape = hr_handler.sample_shape + lr_shape = (hr_shape[0] // self.s_enhance, + hr_shape[1] // self.s_enhance, + hr_shape[2] // self.t_enhance) + msg = (f'hr_handler.sample_shape {hr_handler.sample_shape} and ' + f'lr_handler.sample_shape {lr_handler.sample_shape} are ' + f'incompatible. Must be {hr_shape} and {lr_shape}.') + assert lr_handler.sample_shape == lr_shape, msg + + if hr_handler.data is not None and lr_handler.data is not None: + hr_shape = self.hr_data.shape + lr_shape = (hr_shape[0] // self.s_enhance, + hr_shape[1] // self.s_enhance, + hr_shape[2] // self.t_enhance, hr_shape[3]) + msg = (f'hr_data.shape {self.hr_data.shape} and ' + f'lr_data.shape {self.lr_data.shape} are ' + f'incompatible. Must be {hr_shape} and {lr_shape}.') + assert self.lr_data.shape == lr_shape, msg + + if self.lr_val_data is not None and self.hr_val_data is not None: + hr_shape = self.hr_val_data.shape + lr_shape = (hr_shape[0] // self.s_enhance, + hr_shape[1] // self.s_enhance, + hr_shape[2] // self.t_enhance, hr_shape[3]) + msg = (f'hr_val_data.shape {self.hr_val_data.shape} ' + f'and lr_val_data.shape {self.lr_val_data.shape}' + f' are incompatible. 
Must be {hr_shape} and {lr_shape}.') + assert self.lr_val_data.shape == lr_shape, msg + + @property + def grid_mem(self): + """Get memory used by a feature at a single time step + + Returns + ------- + int + Number of bytes for a single feature array at a single time step + """ + grid_mem = np.product(self.lr_grid_shape) + # assuming feature arrays are float32 (4 bytes) + return 4 * grid_mem + + @property + def feature_mem(self): + """Number of bytes for a single feature array. Used to estimate + max_workers. + + Returns + ------- + int + Number of bytes for a single feature array + """ + feature_mem = self.grid_mem * len(self.lr_time_index) + return feature_mem + + @property + def sample_shape(self): + """Get lr sample shape""" + return self.lr_dh.sample_shape + + @property + def lr_sample_shape(self): + """Get lr sample shape""" + return self.lr_dh.sample_shape + + @property + def hr_sample_shape(self): + """Get hr sample shape""" + return self.hr_dh.sample_shape + + @property + def features(self): + """Get list of features in each data handler""" + return self.hr_dh.features + + @property + def data(self): + """Get low res data. Same as self.lr_data but used to match property + used by batch handler""" + return self.lr_data + + @property + def lr_input_data(self): + """Get low res data used as input to regridding routine""" + if self.lr_dh.data is None: + self.lr_dh.load_cached_data() + return self.lr_dh.data[..., :self.lr_required_shape[2], :] + + @property + def shape(self): + """Get low_res shape""" + return self.lr_dh.shape + + @property + def lr_required_shape(self): + """Return required shape for regridded low_res data""" + return (self.hr_dh.requested_shape[0] // self.s_enhance, + self.hr_dh.requested_shape[1] // self.s_enhance, + self.hr_dh.requested_shape[2] // self.t_enhance) + + @property + def hr_required_shape(self): + """Return required shape for high_res data""" + return (self.s_enhance * self.lr_required_shape[0], + self.s_enhance * self.lr_required_shape[1], + self.t_enhance * self.lr_required_shape[2]) + + @property + def lr_grid_shape(self): + """Return grid shape for regridded low_res data""" + return (self.lr_required_shape[0], self.lr_required_shape[1]) + + @property + def lr_requested_shape(self): + """Return requested shape for low_res data""" + return (*self.lr_required_shape, len(self.features)) + + @property + def lr_lat_lon(self): + """Get low_res lat lon array""" + if self._lr_lat_lon is None: + self._lr_lat_lon = spatial_coarsening(self.hr_lat_lon, + s_enhance=self.s_enhance, + obs_axis=False) + return self._lr_lat_lon + + @lr_lat_lon.setter + def lr_lat_lon(self, lat_lon): + """Set low_res lat lon array""" + self._lr_lat_lon = lat_lon + + @property + def hr_lat_lon(self): + """Get high_res lat lon array""" + if self._hr_lat_lon is None: + self._hr_lat_lon = self.hr_dh.lat_lon[:self.hr_required_shape[0], : + self.hr_required_shape[1]] + return self._hr_lat_lon + + @hr_lat_lon.setter + def hr_lat_lon(self, lat_lon): + """Set high_res lat lon array""" + self._hr_lat_lon = lat_lon + + @property + def regrid_cache_files(self): + """Get file names of regridded cache data""" + cache_files = self._get_cache_file_names(self.regrid_cache_pattern, + grid_shape=self.lr_grid_shape, + time_index=self.lr_time_index, + target=self.hr_dh.target, + features=self.hr_dh.features) + return cache_files + + @property + def try_load(self): + """Check if we should try to load cached data""" + try_load = self._should_load_cache(self.regrid_cache_pattern, + self.regrid_cache_files, 
+ self.overwrite_regrid_cache) + return try_load + + def load_lr_cached_data(self): + """Load low_res cache data""" + + regridded_data = np.full(shape=self.lr_requested_shape, + fill_value=np.nan, + dtype=np.float32) + + logger.info( + f'Loading cache with requested_shape={self.lr_requested_shape}.') + self._load_cached_data(regridded_data, + self.regrid_cache_files, + self.features, + max_workers=self.hr_dh.load_workers) + + self.lr_data = regridded_data + + def load_cached_data(self): + """Load regridded low_res and high_res cache data""" + self.load_lr_cached_data() + self._shape_check() + self._val_split_check() + + def check_clear_data(self): + """Check if data was cached and free memory if load_cached is False""" + if self.regrid_cache_pattern is not None and not self.load_cached: + self.lr_data = None + self.lr_val_data = None + self.hr_dh.check_clear_data() + + def get_lr_data(self): + """Check if era data is cached. If not then extract data and regrid. + Save to cache if cache pattern provided.""" + + if self.try_load: + self.load_lr_cached_data() + else: + regridded_data = self.regrid_lr_data() + + if self.regrid_cache_pattern is not None: + logger.info('Caching low resolution data with ' + f'shape={regridded_data.shape}.') + self._cache_data(regridded_data, + features=self.features, + cache_file_paths=self.regrid_cache_files, + overwrite=self.overwrite_regrid_cache) + self.lr_data = regridded_data + + def get_regridder(self): + """Get regridder object""" + input_meta = pd.DataFrame() + input_meta['latitude'] = self.lr_dh.lat_lon[..., 0].flatten() + input_meta['longitude'] = self.lr_dh.lat_lon[..., 1].flatten() + target_meta = pd.DataFrame() + target_meta['latitude'] = self.lr_lat_lon[..., 0].flatten() + target_meta['longitude'] = self.lr_lat_lon[..., 1].flatten() + return Regridder(input_meta, + target_meta, + max_workers=self.regrid_workers) + + def regrid_lr_data(self): + """Regrid low_res data for all requested features + + Returns + ------- + out : ndarray + Array of regridded low_res data with all features + (spatial_1, spatial_2, temporal, n_features) + """ + logger.info('Regridding low resolution feature data.') + regridder = self.get_regridder() + + out = [] + for i in range(len(self.features)): + tmp = regridder(self.lr_input_data[..., i]) + tmp = tmp.reshape(self.lr_required_shape)[..., np.newaxis] + out.append(tmp) + return np.concatenate(out, axis=-1) + + def get_next(self): + """Get next high_res + low_res. Gets random spatiotemporal sample for + h5 data and then uses enhancement factors to subsample + interpolated/regridded low_res data for same spatiotemporal extent. 
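To make the sampling in get_next concrete, here is a small worked example of how a low-res observation index maps to its high-res counterpart (numbers are illustrative):

    # Illustrative values only: with s_enhance=5 and t_enhance=4, a sampled
    # low-res index of
    #   lr_obs_idx = (slice(2, 4), slice(1, 3), slice(10, 12), feature_idx)
    # is scaled to
    #   hr_obs_idx = (slice(10, 20), slice(5, 15), slice(40, 48), feature_idx)
    # i.e. the spatial slices are multiplied by s_enhance, the temporal slice
    # by t_enhance, and the feature index passes through unchanged, so both
    # samples cover the same physical extent at their own resolutions.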
+ + Returns + ------- + hr_data : ndarray + Array of high resolution data with each feature equal in shape to + hr_sample_shape + lr_data : ndarray + Array of low resolution data with each feature equal in shape to + lr_sample_shape + """ + lr_obs_idx = self._get_observation_index(self.lr_data, + self.lr_sample_shape) + hr_obs_idx = [] + for s in lr_obs_idx[:2]: + hr_obs_idx.append( + slice(s.start * self.s_enhance, s.stop * self.s_enhance)) + for s in lr_obs_idx[2:-1]: + hr_obs_idx.append( + slice(s.start * self.t_enhance, s.stop * self.t_enhance)) + hr_obs_idx.append(lr_obs_idx[-1]) + hr_obs_idx = tuple(hr_obs_idx) + self.current_obs_index = { + 'hr_index': hr_obs_idx, + 'lr_index': lr_obs_idx + } + return self.hr_data[hr_obs_idx], self.lr_data[lr_obs_idx] diff --git a/sup3r/preprocessing/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py similarity index 100% rename from sup3r/preprocessing/exogenous_data_handling.py rename to sup3r/preprocessing/data_handling/exogenous_data_handling.py diff --git a/sup3r/preprocessing/data_handling/h5_data_handling.py b/sup3r/preprocessing/data_handling/h5_data_handling.py new file mode 100644 index 000000000..2866b772e --- /dev/null +++ b/sup3r/preprocessing/data_handling/h5_data_handling.py @@ -0,0 +1,586 @@ +"""Data handling for H5 files. +@author: bbenton +""" + +import copy +import logging +import os + +import numpy as np +from rex import MultiFileNSRDBX, MultiFileWindX + +from sup3r.preprocessing.data_handling.base import DataHandler, DataHandlerDC +from sup3r.preprocessing.feature_handling import ( + BVFreqMon, + BVFreqSquaredH5, + ClearSkyRatioH5, + CloudMaskH5, + LatLonH5, + Rews, + TopoH5, + UWind, + VWind, +) +from sup3r.utilities.utilities import ( + daily_temporal_coarsening, + uniform_box_sampler, +) + +np.random.seed(42) + +logger = logging.getLogger(__name__) + + +class DataHandlerH5(DataHandler): + """DataHandler for H5 Data""" + + # the handler from rex to open h5 data. + REX_HANDLER = MultiFileWindX + + @classmethod + def source_handler(cls, file_paths, **kwargs): + """Rex data handler + + Note that xarray appears to treat open file handlers as singletons + within a threadpool, so its okay to open this source_handler without a + context handler or a .close() statement. + + Parameters + ---------- + file_paths : str | list + paths to data files + kwargs : dict + keyword arguments passed to source handler + + Returns + ------- + data : ResourceX + """ + return cls.REX_HANDLER(file_paths, **kwargs) + + @classmethod + def get_full_domain(cls, file_paths): + """Get target and shape for largest domain possible""" + msg = ( + 'You must either provide the target+shape inputs or an ' + 'existing raster_file input.' 
+ ) + logger.error(msg) + raise ValueError(msg) + + @classmethod + def get_time_index(cls, file_paths, max_workers=None, **kwargs): + """Get time index from data files + + Parameters + ---------- + file_paths : list + path to data file + max_workers : int | None + placeholder to match signature + kwargs : dict + placeholder to match signature + + Returns + ------- + time_index : pd.DateTimeIndex + Time index from h5 source file(s) + """ + handle = cls.source_handler(file_paths) + time_index = handle.time_index + return time_index + + @classmethod + def feature_registry(cls): + """Registry of methods for computing features or extracting renamed + features + + Returns + ------- + dict + Method registry + """ + registry = { + 'BVF2_(.*)m': BVFreqSquaredH5, + 'BVF_MO_(.*)m': BVFreqMon, + 'U_(.*)m': UWind, + 'V_(.*)m': VWind, + 'lat_lon': LatLonH5, + 'REWS_(.*)m': Rews, + 'RMOL': 'inversemoninobukhovlength_2m', + 'P_(.*)m': 'pressure_(.*)m', + 'topography': TopoH5, + 'cloud_mask': CloudMaskH5, + 'clearsky_ratio': ClearSkyRatioH5, + } + return registry + + @classmethod + def extract_feature( + cls, + file_paths, + raster_index, + feature, + time_slice=slice(None), + **kwargs, + ): + """Extract single feature from data source + + Parameters + ---------- + file_paths : list + path to data file + raster_index : ndarray + Raster index array + feature : str + Feature to extract from data + time_slice : slice + slice of time to extract + kwargs : dict + keyword arguments passed to source handler + + Returns + ------- + ndarray + Data array for extracted feature + (spatial_1, spatial_2, temporal) + """ + logger.info(f'Extracting {feature} with kwargs={kwargs}') + handle = cls.source_handler(file_paths, **kwargs) + try: + fdata = handle[ + (feature, time_slice, *tuple([raster_index.flatten()])) + ] + except ValueError as e: + msg = f'{feature} cannot be extracted from source data' + logger.exception(msg) + raise ValueError(msg) from e + + fdata = fdata.reshape( + (-1, raster_index.shape[0], raster_index.shape[1]) + ) + fdata = np.transpose(fdata, (1, 2, 0)) + return fdata.astype(np.float32) + + def get_raster_index(self): + """Get raster index for file data. Here we assume the list of paths in + file_paths all have data with the same spatial domain. We use the first + file in the list to compute the raster. 
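The feature registry above keys compute/extract methods by regular-expression patterns; a quick illustration of how a requested name resolves against a pattern such as 'U_(.*)m' (this assumes plain re matching, which is performed by the feature handling utilities rather than in this handler):

    # Illustrative only -- the real pattern lookup lives in
    # sup3r.preprocessing.feature_handling, not in this file.
    import re

    match = re.match('U_(.*)m', 'U_100m')
    assert match is not None
    assert match.group(1) == '100'   # hub height parsed from the feature name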
+
+        Returns
+        -------
+        raster_index : np.ndarray
+            2D array of grid indices
+        """
+        if self.raster_file is not None and os.path.exists(self.raster_file):
+            logger.debug(
+                f'Loading raster index: {self.raster_file} '
+                f'for {self.input_file_info}'
+            )
+            raster_index = np.loadtxt(self.raster_file).astype(np.uint32)
+        else:
+            check = self.grid_shape is not None and self.target is not None
+            msg = (
+                'Must provide raster file or shape + target to get '
+                'raster index'
+            )
+            assert check, msg
+            logger.debug(
+                'Calculating raster index from WTK file '
+                f'for shape {self.grid_shape} and target '
+                f'{self.target}'
+            )
+            handle = self.source_handler(self.file_paths[0])
+            raster_index = handle.get_raster_index(
+                self.target, self.grid_shape, max_delta=self.max_delta
+            )
+            if self.raster_file is not None:
+                basedir = os.path.dirname(self.raster_file)
+                if not os.path.exists(basedir):
+                    os.makedirs(basedir)
+                logger.debug(f'Saving raster index: {self.raster_file}')
+                np.savetxt(self.raster_file, raster_index)
+        return raster_index
+
+
+class DataHandlerH5WindCC(DataHandlerH5):
+    """Special data handling and batch sampling for h5 wtk or nsrdb data for
+    climate change applications"""
+
+    # the handler from rex to open h5 data.
+    REX_HANDLER = MultiFileWindX
+
+    # list of features / feature name patterns that are input to the generative
+    # model but are not part of the synthetic output and are not sent to the
+    # discriminator. These are case-insensitive and follow the Unix shell-style
+    # wildcard format.
+    TRAIN_ONLY_FEATURES = (
+        'temperature_max_*m',
+        'temperature_min_*m',
+        'relativehumidity_max_*m',
+        'relativehumidity_min_*m',
+    )
+
+    def __init__(self, *args, **kwargs):
+        """
+        Parameters
+        ----------
+        *args : list
+            Same positional args as DataHandlerH5
+        **kwargs : dict
+            Same keyword args as DataHandlerH5
+        """
+        sample_shape = kwargs.get('sample_shape', (10, 10, 24))
+        t_shape = sample_shape[-1]
+
+        if len(sample_shape) == 2:
+            logger.info(
+                'Found 2D sample shape of {}. Adding temporal dim of '
+                '24'.format(sample_shape)
+            )
+            sample_shape = (*sample_shape, 24)
+            t_shape = sample_shape[-1]
+            kwargs['sample_shape'] = sample_shape
+
+        if t_shape < 24 or t_shape % 24 != 0:
+            msg = (
+                'Climate Change DataHandler can only work with temporal '
+                'sample shapes that are one or more days of hourly data '
+                '(e.g. 24, 48, 72...). The requested temporal sample '
+                'shape was: {}'.format(t_shape)
+            )
+            logger.error(msg)
+            raise RuntimeError(msg)
+
+        # validation splits not enabled for solar CC model.
+ kwargs['val_split'] = 0.0 + + super().__init__(*args, **kwargs) + + self.daily_data = None + self.daily_data_slices = None + self.run_daily_averages() + + def run_daily_averages(self): + """Calculate daily average data and store as attribute.""" + msg = ( + 'Data needs to be hourly with at least 24 hours, but data ' + 'shape is {}.'.format(self.data.shape) + ) + assert self.data.shape[2] % 24 == 0, msg + assert self.data.shape[2] > 24, msg + + n_data_days = int(self.data.shape[2] / 24) + daily_data_shape = ( + self.data.shape[0:2] + (n_data_days,) + (self.data.shape[3],) + ) + + logger.info( + 'Calculating daily average datasets for {} training ' + 'data days.'.format(n_data_days) + ) + + self.daily_data = np.zeros(daily_data_shape, dtype=np.float32) + + self.daily_data_slices = np.array_split( + np.arange(self.data.shape[2]), n_data_days + ) + self.daily_data_slices = [ + slice(x[0], x[-1] + 1) for x in self.daily_data_slices + ] + for idf, fname in enumerate(self.features): + for d, t_slice in enumerate(self.daily_data_slices): + if '_max_' in fname: + tmp = np.max(self.data[:, :, t_slice, idf], axis=2) + self.daily_data[:, :, d, idf] = tmp[:, :] + elif '_min_' in fname: + tmp = np.min(self.data[:, :, t_slice, idf], axis=2) + self.daily_data[:, :, d, idf] = tmp[:, :] + else: + tmp = daily_temporal_coarsening( + self.data[:, :, t_slice, idf], temporal_axis=2 + ) + self.daily_data[:, :, d, idf] = tmp[:, :, 0] + + logger.info( + 'Finished calculating daily average datasets for {} ' + 'training data days.'.format(n_data_days) + ) + + def _normalize_data(self, data, val_data, feature_index, mean, std): + """Normalize data with initialized mean and standard deviation for a + specific feature + + Parameters + ---------- + data : np.ndarray + Array of training data. + (spatial_1, spatial_2, temporal, n_features) + val_data : np.ndarray + Array of validation data. + (spatial_1, spatial_2, temporal, n_features) + feature_index : int + index of feature to be normalized + mean : float32 + specified mean of associated feature + std : float32 + specificed standard deviation for associated feature + """ + super()._normalize_data(data, val_data, feature_index, mean, std) + self.daily_data[..., feature_index] -= mean + self.daily_data[..., feature_index] /= std + + @classmethod + def feature_registry(cls): + """Registry of methods for computing features + + Returns + ------- + dict + Method registry + """ + registry = { + 'U_(.*)m': UWind, + 'V_(.*)m': VWind, + 'lat_lon': LatLonH5, + 'topography': TopoH5, + 'temperature_max_(.*)m': 'temperature_(.*)m', + 'temperature_min_(.*)m': 'temperature_(.*)m', + 'relativehumidity_max_(.*)m': 'relativehumidity_(.*)m', + 'relativehumidity_min_(.*)m': 'relativehumidity_(.*)m', + } + return registry + + def get_observation_index(self): + """Randomly gets spatial sample and time sample + + Returns + ------- + obs_ind_hourly : tuple + Tuple of sampled spatial grid, time slice, and features indices. + Used to get single observation like self.data[observation_index]. + This is for hourly high-res data slicing. + obs_ind_daily : tuple + Same as obs_ind_hourly but the temporal index (i=2) is a slice of + the daily data (self.daily_data) with day integers. 
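A short worked example of how the hourly and daily temporal indices pair up (numbers are illustrative):

    # Illustrative values only: with sample_shape=(20, 20, 72) there are
    # n_days = 72 // 24 = 3 daily steps per observation. If the sampler draws
    # rand_day_ind = 5 on a continuous hourly record, then
    #   t_slice_daily  = slice(5, 8)       # three rows of self.daily_data
    #   t_slice_hourly = slice(120, 192)   # the matching 72 hourly steps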
+ """ + spatial_slice = uniform_box_sampler(self.data, self.sample_shape[:2]) + + n_days = int(self.sample_shape[2] / 24) + rand_day_ind = np.random.choice(len(self.daily_data_slices) - n_days) + t_slice_0 = self.daily_data_slices[rand_day_ind] + t_slice_1 = self.daily_data_slices[rand_day_ind + n_days - 1] + t_slice_hourly = slice(t_slice_0.start, t_slice_1.stop) + t_slice_daily = slice(rand_day_ind, rand_day_ind + n_days) + + obs_ind_hourly = tuple( + [*spatial_slice, t_slice_hourly, np.arange(len(self.features))] + ) + + obs_ind_daily = tuple( + [*spatial_slice, t_slice_daily, np.arange(len(self.features))] + ) + + return obs_ind_hourly, obs_ind_daily + + def get_next(self): + """Get data for observation using random observation index. Loops + repeatedly over randomized time index + + Returns + ------- + obs_hourly : np.ndarray + 4D array + (spatial_1, spatial_2, temporal_hourly, features) + obs_daily_avg : np.ndarray + 4D array but the temporal axis is temporal_hourly//24 + (spatial_1, spatial_2, temporal_daily, features) + """ + obs_ind_hourly, obs_ind_daily = self.get_observation_index() + self.current_obs_index = obs_ind_hourly + obs_hourly = self.data[obs_ind_hourly] + obs_daily_avg = self.daily_data[obs_ind_daily] + return obs_hourly, obs_daily_avg + + def split_data(self, data=None, val_split=0.0, shuffle_time=False): + """Split time dimension into set of training indices and validation + indices. For NSRDB it makes sure that the splits happen at midnight. + + Parameters + ---------- + data : np.ndarray + 4D array of high res data + (spatial_1, spatial_2, temporal, features) + val_split : float + Fraction of data to separate for validation. + shuffle_time : bool + No effect. Used to fit base class function signature. + + Returns + ------- + data : np.ndarray + (spatial_1, spatial_2, temporal, features) + Training data fraction of initial data array. Initial data array is + overwritten by this new data array. + val_data : np.ndarray + (spatial_1, spatial_2, temporal, features) + Validation data fraction of initial data array. + """ + + if data is not None: + self.data = data + + midnight_ilocs = np.where( + (self.time_index.hour == 0) + & (self.time_index.minute == 0) + & (self.time_index.second == 0) + )[0] + + n_val_obs = int(np.ceil(val_split * len(midnight_ilocs))) + val_split_index = midnight_ilocs[n_val_obs] + + self.val_data = self.data[:, :, slice(None, val_split_index), :] + self.data = self.data[:, :, slice(val_split_index, None), :] + + self.val_time_index = self.time_index[slice(None, val_split_index)] + self.time_index = self.time_index[slice(val_split_index, None)] + + return self.data, self.val_data + + +class DataHandlerH5SolarCC(DataHandlerH5WindCC): + """Special data handling and batch sampling for h5 NSRDB solar data for + climate change applications""" + + # the handler from rex to open h5 data. + REX_HANDLER = MultiFileNSRDBX + + # list of features / feature name patterns that are input to the generative + # model but are not part of the synthetic output and are not sent to the + # discriminator. These are case-insensitive and follow the Unix shell-style + # wildcard format. 
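As a reminder of the wildcard semantics described in the comment above, a quick illustration using Python's fnmatch (the package's actual matching helper lives elsewhere; this only shows the pattern behavior):

    # Illustrative only: shell-style, case-insensitive matching of feature
    # names against TRAIN_ONLY_FEATURES patterns.
    from fnmatch import fnmatch

    assert fnmatch('u_200m'.lower(), 'U*'.lower())
    assert fnmatch('temperature_max_2m', 'temperature_max_*m')
    assert not fnmatch('temperature_2m', 'temperature_max_*m')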
+ TRAIN_ONLY_FEATURES = ('U*', 'V*', 'topography') + + def __init__(self, *args, **kwargs): + """ + Parameters + ---------- + *args : list + Same positional args as DataHandlerH5 + **kwargs : dict + Same keyword args as DataHandlerH5 + """ + + args = copy.deepcopy(args) # safe copy for manipulation + required = ['ghi', 'clearsky_ghi', 'clearsky_ratio'] + missing = [dset for dset in required if dset not in args[1]] + if any(missing): + msg = ( + 'Cannot initialize DataHandlerH5SolarCC without required ' + 'features {}. All three are necessary to get the daily ' + 'average clearsky ratio (ghi sum / clearsky ghi sum), ' + 'even though only the clearsky ratio will be passed to the ' + 'GAN.'.format(required) + ) + logger.error(msg) + raise KeyError(msg) + + super().__init__(*args, **kwargs) + + @classmethod + def feature_registry(cls): + """Registry of methods for computing features + + Returns + ------- + dict + Method registry + """ + registry = { + 'U': UWind, + 'V': VWind, + 'windspeed': 'wind_speed', + 'winddirection': 'wind_direction', + 'lat_lon': LatLonH5, + 'cloud_mask': CloudMaskH5, + 'clearsky_ratio': ClearSkyRatioH5, + 'topography': TopoH5, + } + return registry + + def run_daily_averages(self): + """Calculate daily average data and store as attribute. + + Note that the H5 clearsky ratio feature requires special logic to match + the climate change dataset of daily average GHI / daily average CS_GHI. + This target climate change dataset is not equivalent to the average of + instantaneous hourly clearsky ratios + """ + + msg = ( + 'Data needs to be hourly with at least 24 hours, but data ' + 'shape is {}.'.format(self.data.shape) + ) + assert self.data.shape[2] % 24 == 0, msg + assert self.data.shape[2] > 24, msg + + n_data_days = int(self.data.shape[2] / 24) + daily_data_shape = ( + self.data.shape[0:2] + (n_data_days,) + (self.data.shape[3],) + ) + + logger.info( + 'Calculating daily average datasets for {} training ' + 'data days.'.format(n_data_days) + ) + + self.daily_data = np.zeros(daily_data_shape, dtype=np.float32) + + self.daily_data_slices = np.array_split( + np.arange(self.data.shape[2]), n_data_days + ) + self.daily_data_slices = [ + slice(x[0], x[-1] + 1) for x in self.daily_data_slices + ] + + i_ghi = self.features.index('ghi') + i_cs = self.features.index('clearsky_ghi') + i_ratio = self.features.index('clearsky_ratio') + + for d, t_slice in enumerate(self.daily_data_slices): + for idf in range(self.data.shape[-1]): + self.daily_data[:, :, d, idf] = daily_temporal_coarsening( + self.data[:, :, t_slice, idf], temporal_axis=2 + )[:, :, 0] + + # note that this ratio of daily irradiance sums is not the same as + # the average of hourly ratios. + total_ghi = np.nansum(self.data[:, :, t_slice, i_ghi], axis=2) + total_cs_ghi = np.nansum(self.data[:, :, t_slice, i_cs], axis=2) + avg_cs_ratio = total_ghi / total_cs_ghi + self.daily_data[:, :, d, i_ratio] = avg_cs_ratio + + # remove ghi and clearsky ghi from feature set. These shouldn't be used + # downstream for solar cc and keeping them confuses the batch handler + logger.info( + 'Finished calculating daily average clearsky_ratio, ' + 'removing ghi and clearsky_ghi from the ' + 'DataHandlerH5SolarCC feature list.' 
+ ) + ifeats = np.array( + [i for i in range(len(self.features)) if i not in (i_ghi, i_cs)] + ) + self.data = self.data[..., ifeats] + self.daily_data = self.daily_data[..., ifeats] + self.features.remove('ghi') + self.features.remove('clearsky_ghi') + + logger.info( + 'Finished calculating daily average datasets for {} ' + 'training data days.'.format(n_data_days) + ) + + +class DataHandlerDCforH5(DataHandlerH5, DataHandlerDC): + """Data centric data handler for H5 files""" diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py new file mode 100644 index 000000000..3019d3b05 --- /dev/null +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -0,0 +1,1101 @@ +"""MixIn classes for data handling. +@author: bbenton +""" + +import glob +import logging +import os +import pickle +import warnings +from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime as dt + +import numpy as np +import pandas as pd +from scipy.stats import mode + +from sup3r.utilities.utilities import ( + get_source_type, + ignore_case_path_fetch, + uniform_box_sampler, + uniform_time_sampler, +) + +np.random.seed(42) + +logger = logging.getLogger(__name__) + + +class CacheHandlingMixIn: + """Collection of methods for handling data caching and loading""" + + def _get_timestamp_0(self, time_index): + """Get a string timestamp for the first time index value with the + format YYYYMMDDHHMMSS""" + + time_stamp = time_index[0] + yyyy = str(time_stamp.year) + mm = str(time_stamp.month).zfill(2) + dd = str(time_stamp.day).zfill(2) + hh = str(time_stamp.hour).zfill(2) + min = str(time_stamp.minute).zfill(2) + ss = str(time_stamp.second).zfill(2) + ts0 = yyyy + mm + dd + hh + min + ss + return ts0 + + def _get_timestamp_1(self, time_index): + """Get a string timestamp for the last time index value with the + format YYYYMMDDHHMMSS""" + + time_stamp = time_index[-1] + yyyy = str(time_stamp.year) + mm = str(time_stamp.month).zfill(2) + dd = str(time_stamp.day).zfill(2) + hh = str(time_stamp.hour).zfill(2) + min = str(time_stamp.minute).zfill(2) + ss = str(time_stamp.second).zfill(2) + ts1 = yyyy + mm + dd + hh + min + ss + return ts1 + + def _get_cache_pattern(self, cache_pattern): + """Get correct cache file pattern for formatting. + + Returns + ------- + cache_pattern : str + The cache file pattern with formatting keys included. 
+ """ + if cache_pattern is not None: + if '.pkl' not in cache_pattern: + cache_pattern += '.pkl' + if '{feature}' not in cache_pattern: + cache_pattern = cache_pattern.replace('.pkl', '_{feature}.pkl') + return cache_pattern + + def _get_cache_file_names( + self, + cache_pattern, + grid_shape, + time_index, + target, + features, + ): + """Get names of cache files from cache_pattern and feature names + + Parameters + ---------- + cache_pattern : str + Pattern to use for cache file names + grid_shape : tuple + Shape of grid to use for cache file naming + time_index : list | pd.DatetimeIndex + Time index to use for cache file naming + target : tuple + Target to use for cache file naming + features : list + List of features to use for cache file naming + + Returns + ------- + list + List of cache file names + """ + cache_pattern = self._get_cache_pattern(cache_pattern) + if cache_pattern is not None: + if '{feature}' not in cache_pattern: + cache_pattern = '{feature}_' + cache_pattern + cache_files = [ + cache_pattern.replace('{feature}', f.lower()) for f in features + ] + for i, f in enumerate(cache_files): + if '{shape}' in f: + shape = f'{grid_shape[0]}x{grid_shape[1]}' + shape += f'x{len(time_index)}' + f = f.replace('{shape}', shape) + if '{target}' in f: + target_str = f'{target[0]:.2f}_{target[1]:.2f}' + f = f.replace('{target}', target_str) + if '{times}' in f: + ts_0 = self._get_timestamp_0(time_index) + ts_1 = self._get_timestamp_1(time_index) + times = f'{ts_0}_{ts_1}' + f = f.replace('{times}', times) + + cache_files[i] = f + + for i, fp in enumerate(cache_files): + fp_check = ignore_case_path_fetch(fp) + if fp_check is not None: + cache_files[i] = fp_check + else: + cache_files = None + + return cache_files + + def _cache_data(self, data, features, cache_file_paths, overwrite=False): + """Cache feature data to files + + Parameters + ---------- + data : ndarray + Array of feature data to save to cache files + features : list + List of feature names. + cache_file_paths : str | None + Path to file for saving feature data + overwrite : bool + Whether to overwrite exisiting files. + """ + os.makedirs(os.path.dirname(cache_file_paths[0]), exist_ok=True) + for i, fp in enumerate(cache_file_paths): + if not os.path.exists(fp) or overwrite: + if overwrite and os.path.exists(fp): + logger.info( + f'Overwriting {features[i]} with shape ' + f'{data[..., i].shape} to {fp}' + ) + else: + logger.info( + f'Saving {features[i]} with shape ' + f'{data[..., i].shape} to {fp}' + ) + + tmp_file = fp.replace('.pkl', '.pkl.tmp') + with open(tmp_file, 'wb') as fh: + pickle.dump(data[..., i], fh, protocol=4) + os.replace(tmp_file, fp) + else: + msg = ( + f'Called cache_data but {fp} already exists. Set to ' + 'overwrite_cache to True to overwrite.' + ) + logger.warning(msg) + warnings.warn(msg) + + def _load_single_cached_feature( + self, fp, cache_files, features, required_shape + ): + """Load single feature from given file + + Parameters + ---------- + fp : string + File path for feature cache file + cache_files : list + List of cache files for each feature + features : list + List of requested features + required_shape : tuple + Required shape for full array of feature data + + Returns + ------- + out : ndarray + Array of data for given feature file. 
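To illustrate the naming scheme implemented above, here is roughly how a pattern containing every supported key resolves (grid, target, and time index values are hypothetical):

    # Illustrative only: features=['u_100m'], grid_shape=(20, 20), an hourly
    # 2015 time index (8760 steps), and target=(39.0, -105.0) turn
    #   './cache/sup3r_{feature}_{shape}_{target}_{times}.pkl'
    # into approximately
    #   './cache/sup3r_u_100m_20x20x8760_39.00_-105.00_'
    #   '20150101000000_20151231230000.pkl'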
+ + Raises + ------ + RuntimeError + Error raised if shape conflicts with requested shape + """ + idx = cache_files.index(fp) + assert features[idx].lower() in fp.lower() + fp = ignore_case_path_fetch(fp) + logger.info(f'Loading {features[idx]} from ' f'{fp}.') + + out = None + with open(fp, 'rb') as fh: + out = np.array(pickle.load(fh), dtype=np.float32) + msg = ( + 'Data loaded from from cache file "{}" ' + 'could not be written to feature channel {} ' + 'of full data array of shape {}. ' + 'The cached data has the wrong shape {}.'.format( + fp, idx, required_shape, out.shape + ) + ) + assert out.shape == required_shape, msg + return out + + def _should_load_cache( + self, cache_pattern, cache_files, overwrite_cache=False + ): + """Check if we should load cached data""" + try_load = ( + cache_pattern is not None + and not overwrite_cache + and all(os.path.exists(fp) for fp in cache_files) + ) + return try_load + + def parallel_load(self, data, cache_files, features, max_workers=None): + """Load feature data in parallel + + Parameters + ---------- + data : ndarray + Array to fill with cached data + cache_files : list + List of cache files for each feature + features : list + List of requested features + max_workers : int | None + Max number of workers to use for parallel data loading. If None + the max number of available workers will be used. + """ + logger.info( + f'Loading {len(cache_files)} cache files with ' + f'max_workers={max_workers}.' + ) + with ThreadPoolExecutor(max_workers=max_workers) as exe: + futures = {} + now = dt.now() + for i, fp in enumerate(cache_files): + future = exe.submit( + self._load_single_cached_feature, + fp=fp, + cache_files=cache_files, + features=features, + required_shape=data.shape[:-1], + ) + futures[future] = {'idx': i, 'fp': os.path.basename(fp)} + + logger.info( + f'Started loading all {len(cache_files)} cache ' + f'files in {dt.now() - now}.' + ) + + for i, future in enumerate(as_completed(futures)): + try: + data[..., futures[future]['idx']] = future.result() + except Exception as e: + msg = ( + 'Error while loading ' + f'{cache_files[futures[future]["idx"]]}' + ) + logger.exception(msg) + raise RuntimeError(msg) from e + logger.debug( + f'{i+1} out of {len(futures)} cache files ' + f'loaded: {futures[future]["fp"]}' + ) + + def _load_cached_data(self, data, cache_files, features, max_workers=None): + """Load cached data to provided array + + Parameters + ---------- + data : ndarray + Array to fill with cached data + cache_files : list + List of cache files for each feature + features : list + List of requested features + required_shape : tuple + Required shape for full array of feature data + max_workers : int | None + Max number of workers to use for parallel data loading. If None + the max number of available workers will be used. + """ + if max_workers == 1: + for i, fp in enumerate(cache_files): + out = self._load_single_cached_feature( + fp, cache_files, features, data.shape[:-1] + ) + msg = ( + 'Data loaded from from cache file "{}" ' + 'could not be written to feature channel {} ' + 'of full data array of shape {}. 
' + 'The cached data has the wrong shape {}.'.format( + fp, i, data[..., i].shape, out.shape + ) + ) + assert data[..., i].shape == out.shape, msg + data[..., i] = out + + else: + self.parallel_load( + data, + cache_files, + features, + max_workers=max_workers, + ) + + @classmethod + def check_cached_features( + cls, + features, + cache_files=None, + overwrite_cache=False, + load_cached=False, + ): + """Check which features have been cached and check flags to determine + whether to load or extract this features again + + Parameters + ---------- + features : list + list of features to extract + cache_files : list | None + Path to files with saved feature data + overwrite_cache : bool + Whether to overwrite cached files + load_cached : bool + Whether to load data from cache files + + Returns + ------- + list + List of features to extract. Might not include features which have + cache files. + """ + extract_features = [] + # check if any features can be loaded from cache + if cache_files is not None: + for i, f in enumerate(features): + check = ( + os.path.exists(cache_files[i]) + and f.lower() in cache_files[i].lower() + ) + if check: + if not overwrite_cache: + if load_cached: + msg = ( + f'{f} found in cache file {cache_files[i]}.' + ' Loading from cache instead of extracting ' + 'from source files' + ) + logger.info(msg) + else: + msg = ( + f'{f} found in cache file {cache_files[i]}.' + ' Call load_cached_data() or use ' + 'load_cached=True to load this data.' + ) + logger.info(msg) + else: + msg = ( + f'{cache_files[i]} exists but overwrite_cache ' + 'is set to True. Proceeding with extraction.' + ) + logger.info(msg) + extract_features.append(f) + else: + extract_features.append(f) + else: + extract_features = features + + return extract_features + + +class InputMixIn(CacheHandlingMixIn): + """MixIn class with properties and methods for handling the spatiotemporal + data domain to extract from source data.""" + + def __init__( + self, + target, + shape, + raster_file=None, + raster_index=None, + temporal_slice=slice(None, None, 1), + ): + """Provide properties of the spatiotemporal data domain + + Parameters + ---------- + target : tuple + (lat, lon) lower left corner of raster. Either need target+shape or + raster_file. + shape : tuple + (rows, cols) grid size. Either need target+shape or raster_file. + raster_file : str | None + File for raster_index array for the corresponding target and shape. + If specified the raster_index will be loaded from the file if it + exists or written to the file if it does not yet exist. If None and + raster_index is not provided raster_index will be calculated + directly. Either need target+shape, raster_file, or raster_index + input. + raster_index : list + List of tuples or slices. Used as an alternative to computing the + raster index from target+shape or loading the raster index from + file + temporal_slice : slice + Slice specifying extent and step of temporal extraction. e.g. + slice(start, stop, time_pruning). If equal to slice(None, None, 1) + the full time dimension is selected. 
+ """ + self.raster_file = raster_file + self.target = target + self.grid_shape = shape + self.raster_index = raster_index + self.temporal_slice = temporal_slice + self.lat_lon = None + self.overwrite_ti_cache = False + self.max_workers = None + self._ti_workers = None + self._raw_time_index = None + self._raw_tsteps = None + self._time_index = None + self._time_index_file = None + self._file_paths = None + self._cache_pattern = None + self._invert_lat = None + self._raw_lat_lon = None + self._full_raw_lat_lon = None + self._single_ts_files = None + self._worker_attrs = ['ti_workers'] + self.res_kwargs = {} + + @property + def raw_tsteps(self): + """Get number of time steps for all input files""" + if self._raw_tsteps is None: + if self.single_ts_files: + self._raw_tsteps = len(self.file_paths) + else: + self._raw_tsteps = len(self.raw_time_index) + return self._raw_tsteps + + @property + def single_ts_files(self): + """Check if there is a file for each time step, in which case we can + send a subset of files to the data handler according to ti_pad_slice""" + if self._single_ts_files is None: + logger.debug('Checking if input files are single timestep.') + t_steps = self.get_time_index(self.file_paths[:1], max_workers=1) + check = ( + len(self._file_paths) == len(self.raw_time_index) + and t_steps is not None + and len(t_steps) == 1 + ) + self._single_ts_files = check + return self._single_ts_files + + @staticmethod + def get_capped_workers(max_workers_cap, max_workers): + """Get max number of workers for a given job. Capped to global max + workers if specified + + Parameters + ---------- + max_workers_cap : int | None + Cap for job specific max_workers + max_workers : int | None + Job specific max_workers + + Returns + ------- + max_workers : int | None + job specific max_workers capped by max_workers_cap if provided + """ + if max_workers is None and max_workers_cap is None: + return max_workers + elif max_workers_cap is not None and max_workers is None: + return max_workers_cap + elif max_workers is not None and max_workers_cap is None: + return max_workers + else: + return np.min((max_workers_cap, max_workers)) + + def cap_worker_args(self, max_workers): + """Cap all workers args by max_workers""" + for v in self._worker_attrs: + capped_val = self.get_capped_workers(getattr(self, v), max_workers) + setattr(self, v, capped_val) + + @classmethod + @abstractmethod + def get_full_domain(cls, file_paths): + """Get full lat/lon grid for when target + shape are not specified""" + + @classmethod + @abstractmethod + def get_lat_lon(cls, file_paths, raster_index, invert_lat=False): + """Get lat/lon grid for requested target and shape""" + + @abstractmethod + def get_time_index(self, file_paths, max_workers=None, **kwargs): + """Get raw time index for source data""" + + @property + def input_file_info(self): + """Method to provide info about files in log output. Since NETCDF files + have single time slices printing out all the file paths is just a text + dump without much info. + + Returns + ------- + str + message to append to log output that does not include a huge info + dump of file paths + """ + msg = ( + f'source files with dates from {self.raw_time_index[0]} to ' + f'{self.raw_time_index[-1]}' + ) + return msg + + @property + def temporal_slice(self): + """Get temporal range to extract from full dataset""" + return self._temporal_slice + + @temporal_slice.setter + def temporal_slice(self, temporal_slice): + """Make sure temporal_slice is a slice. 
Need to do this because json
+        cannot save slices so we can instead save as list and then convert.
+
+        Parameters
+        ----------
+        temporal_slice : tuple | list | slice
+            Time range to extract from input data. If a list or tuple it will
+            be converted to a slice. Tuple or list must have at least two
+            elements and no more than three, corresponding to the inputs of
+            slice()
+        """
+        msg = 'temporal_slice must be tuple, list, or slice'
+        assert isinstance(temporal_slice, (tuple, list, slice)), msg
+        if isinstance(temporal_slice, slice):
+            self._temporal_slice = temporal_slice
+        else:
+            check = len(temporal_slice) <= 3
+            msg = (
+                'If providing list or tuple for temporal_slice length must '
+                'be <= 3'
+            )
+            assert check, msg
+            self._temporal_slice = slice(*temporal_slice)
+        if self._temporal_slice.step is None:
+            self._temporal_slice = slice(
+                self._temporal_slice.start, self._temporal_slice.stop, 1
+            )
+        if self._temporal_slice.start is None:
+            self._temporal_slice = slice(
+                0, self._temporal_slice.stop, self._temporal_slice.step
+            )
+
+    @property
+    def file_paths(self):
+        """Get file paths for input data"""
+        return self._file_paths
+
+    @file_paths.setter
+    def file_paths(self, file_paths):
+        """Set file paths attr and do initial glob / sort
+
+        Parameters
+        ----------
+        file_paths : str | list
+            A list of files to extract raster data from. Each file must have
+            the same number of timesteps. Can also pass a string with a
+            unix-style file path which will be passed through glob.glob
+        """
+        self._file_paths = file_paths
+        if isinstance(self._file_paths, str):
+            if '*' in file_paths:
+                self._file_paths = glob.glob(self._file_paths)
+            else:
+                self._file_paths = [self._file_paths]
+
+        msg = (
+            'No valid files provided to DataHandler. '
+            f'Received file_paths={file_paths}. Aborting.'
+        )
+        assert file_paths is not None and len(self._file_paths) > 0, msg
+
+        self._file_paths = sorted(self._file_paths)
+
+    @property
+    def ti_workers(self):
+        """Get max number of workers for computing time index"""
+        if self._ti_workers is None:
+            self._ti_workers = len(self._file_paths)
+        return self._ti_workers
+
+    @ti_workers.setter
+    def ti_workers(self, val):
+        """Set max number of workers for computing time index"""
+        self._ti_workers = val
+
+    @property
+    def need_full_domain(self):
+        """Check whether we need to get the full lat/lon grid to determine
+        target and shape values"""
+        no_raster_file = self.raster_file is None or not os.path.exists(
+            self.raster_file
+        )
+        no_target_shape = self._target is None or self._grid_shape is None
+        need_full = no_raster_file and no_target_shape
+
+        if need_full:
+            logger.info(
+                'Target + shape not specified. Getting full domain '
+                f'for {self.file_paths[0]}.'
+            )
+
+        return need_full
+
+    @property
+    def full_raw_lat_lon(self):
+        """Get the full lat/lon grid without doing any latitude inversion"""
+        if self._full_raw_lat_lon is None and self.need_full_domain:
+            self._full_raw_lat_lon = self.get_full_domain(self.file_paths[:1])
+        return self._full_raw_lat_lon
+
+    @property
+    def raw_lat_lon(self):
+        """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon
+        array with same ordering in last dimension. This returns the grid
+        without any lat inversion.
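A small worked example of the descending-latitude convention enforced by the properties that follow (array values are illustrative):

    # Illustrative only: a 2 x 1 grid with ascending raw latitudes
    #   raw_lat_lon[..., 0] = [[39.0], [40.0]]
    # gives lats_are_descending(raw_lat_lon) == False, so invert_lat is True
    # and lat_lon = raw_lat_lon[::-1] puts the lower left corner at
    # lat_lon[-1, 0], which is also where the (lat, lon) target is read from.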
+ + Returns + ------- + ndarray + """ + raster_file_exists = self.raster_file is not None and os.path.exists( + self.raster_file + ) + + if self.full_raw_lat_lon is not None and raster_file_exists: + self._raw_lat_lon = self.full_raw_lat_lon[self.raster_index] + + elif self.full_raw_lat_lon is not None and not raster_file_exists: + self._raw_lat_lon = self.full_raw_lat_lon + + if self._raw_lat_lon is None: + self._raw_lat_lon = self.get_lat_lon( + self.file_paths[0:1], self.raster_index, invert_lat=False + ) + return self._raw_lat_lon + + @property + def lat_lon(self): + """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon + array with same ordering in last dimension. This ensures that the + lower left hand corner of the domain is given by lat_lon[-1, 0] + + Returns + ------- + ndarray + """ + if self._lat_lon is None: + self._lat_lon = self.raw_lat_lon + if self.invert_lat: + self._lat_lon = self._lat_lon[::-1] + return self._lat_lon + + @lat_lon.setter + def lat_lon(self, lat_lon): + """Update lat lon""" + self._lat_lon = lat_lon + + @property + def latitude(self): + """Return latitude array""" + return self.lat_lon[..., 0] + + @property + def longitude(self): + """Return longitude array""" + return self.lat_lon[..., 1] + + @property + def invert_lat(self): + """Whether to invert the latitude axis during data extraction. This is + to enforce a descending latitude ordering so that the lower left corner + of the grid is at idx=(-1, 0) instead of idx=(0, 0)""" + if self._invert_lat is None: + lat_lon = self.raw_lat_lon + self._invert_lat = not self.lats_are_descending(lat_lon) + return self._invert_lat + + @property + def target(self): + """Get lower left corner of raster + + Returns + ------- + _target: tuple + (lat, lon) lower left corner of raster. + """ + if self._target is None: + lat_lon = self.lat_lon + if not self.lats_are_descending(lat_lon): + self._target = tuple(lat_lon[0, 0, :]) + else: + self._target = tuple(lat_lon[-1, 0, :]) + return self._target + + @target.setter + def target(self, target): + """Update target property""" + self._target = target + + @classmethod + def lats_are_descending(cls, lat_lon): + """Check if latitudes are in descending order (i.e. the target + coordinate is already at the bottom left corner) + + Parameters + ---------- + lat_lon : np.ndarray + Lat/Lon array with shape (n_lats, n_lons, 2) + + Returns + ------- + bool + """ + return lat_lon[-1, 0, 0] < lat_lon[0, 0, 0] + + @property + def grid_shape(self): + """Get shape of raster + + Returns + ------- + _grid_shape: tuple + (rows, cols) grid size. + """ + if self._grid_shape is None: + self._grid_shape = self.lat_lon.shape[:-1] + return self._grid_shape + + @grid_shape.setter + def grid_shape(self, grid_shape): + """Update grid_shape property""" + self._grid_shape = grid_shape + + @property + def source_type(self): + """Get data type for source files. Either nc or h5""" + return get_source_type(self.file_paths) + + @property + def raw_time_index(self): + """Time index for input data without time pruning. 
This is the base + time index for the raw input data.""" + + if self._raw_time_index is None: + check = ( + self.time_index_file is not None + and os.path.exists(self.time_index_file) + and not self.overwrite_ti_cache + ) + if check: + logger.debug( + 'Loading raw_time_index from ' f'{self.time_index_file}' + ) + with open(self.time_index_file, 'rb') as f: + self._raw_time_index = pd.DatetimeIndex(pickle.load(f)) + else: + self._raw_time_index = self._build_and_cache_time_index() + + check = ( + self._raw_time_index is not None + and (self._raw_time_index.hour == 12).all() + ) + if check: + self._raw_time_index -= pd.Timedelta(12, 'h') + elif self._raw_time_index is None: + self._raw_time_index = [None, None] + + if self._single_ts_files: + self.time_index_conflict_check() + return self._raw_time_index + + def time_index_conflict_check(self): + """Check if the number of input files and the length of the time index + is the same""" + msg = ( + f'Number of time steps ({len(self._raw_time_index)}) and files ' + f'({self.raw_tsteps}) conflict!' + ) + check = len(self._raw_time_index) == self.raw_tsteps + assert check, msg + + @property + def time_index(self): + """Time index for input data with time pruning. This is the raw time + index with a cropped range and time step applied.""" + if self._time_index is None: + self._time_index = self.raw_time_index[self.temporal_slice] + return self._time_index + + @time_index.setter + def time_index(self, time_index): + """Update time index""" + self._time_index = time_index + + @property + def time_freq_hours(self): + """Get the time frequency in hours as a float""" + ti_deltas = self.raw_time_index - np.roll(self.raw_time_index, 1) + ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600 + time_freq = float(mode(ti_deltas_hours).mode) + return time_freq + + @property + def cache_pattern(self): + """Get correct cache file pattern for formatting. + + Returns + ------- + _cache_pattern : str + The cache file pattern with formatting keys included. + """ + self._cache_pattern = self._get_cache_pattern(self._cache_pattern) + return self._cache_pattern + + @cache_pattern.setter + def cache_pattern(self, cache_pattern): + """Update the cache file pattern""" + self._cache_pattern = cache_pattern + + @property + def time_index_file(self): + """Get time index file path""" + if self.source_type == 'h5': + return None + + if self.cache_pattern is not None and self._time_index_file is None: + basename = self.cache_pattern.replace('{times}', '') + basename = basename.replace('{shape}', str(len(self.file_paths))) + basename = basename.replace('_{target}', '') + basename = basename.replace('{feature}', 'time_index') + tmp = basename.split('_') + if tmp[-2].isdigit() and tmp[-1].strip('.pkl').isdigit(): + basename = '_'.join(tmp[:-1]) + '.pkl' + self._time_index_file = basename + return self._time_index_file + + def _build_and_cache_time_index(self): + """Build time index and cache if time_index_file is not None""" + now = dt.now() + logger.debug( + f'Getting time index for {len(self.file_paths)} ' + f'input files. 
Using ti_workers={self.ti_workers}' + f' and res_kwargs={self.res_kwargs}' + ) + self._raw_time_index = self.get_time_index( + self.file_paths, max_workers=self.ti_workers, **self.res_kwargs + ) + + if self.time_index_file is not None: + os.makedirs(os.path.dirname(self.time_index_file), exist_ok=True) + logger.debug(f'Saving raw_time_index to {self.time_index_file}') + with open(self.time_index_file, 'wb') as f: + pickle.dump(self._raw_time_index, f) + logger.debug(f'Built full time index in {dt.now() - now} seconds.') + return self._raw_time_index + + +class TrainingPrepMixIn: + """Collection of training related methods. e.g. Training + Validation + splitting, normalization""" + + @classmethod + def _split_data_indices( + cls, data, val_split=0.0, n_val_obs=None, shuffle_time=False + ): + """Split time dimension into set of training indices and validation + indices + + Parameters + ---------- + data : np.ndarray + 4D array of high res data + (spatial_1, spatial_2, temporal, features) + val_split : float + Fraction of data to separate for validation. + n_val_obs : int | None + Optional number of validation observations. If provided this + overrides val_split + shuffle_time : bool + Whether to shuffle time or not. + + Returns + ------- + training_indices : np.ndarray + Array of timestep indices used to select training data. e.g. + training_data = data[..., training_indices, :] + val_indices : np.ndarray + Array of timestep indices used to select validation data. e.g. + val_data = data[..., val_indices, :] + """ + n_observations = data.shape[2] + all_indices = np.arange(n_observations) + n_val_obs = ( + int(val_split * n_observations) if n_val_obs is None else n_val_obs + ) + + if shuffle_time: + np.random.shuffle(all_indices) + + val_indices = all_indices[:n_val_obs] + training_indices = all_indices[n_val_obs:] + + return training_indices, val_indices + + def _get_observation_index(self, data, sample_shape): + """Randomly gets spatial sample and time sample + + Parameters + ---------- + data : ndarray + Array of data to sample + (spatial_1, spatial_2, temporal, n_features) + sample_shape : tuple + Size of observation to sample + (n_lats, n_lons, n_timesteps) + + Returns + ------- + observation_index : tuple + Tuple of sampled spatial grid, time slice, and features indices. + Used to get single observation like self.data[observation_index] + """ + spatial_slice = uniform_box_sampler(data, sample_shape[:2]) + temporal_slice = uniform_time_sampler(data, sample_shape[2]) + return tuple( + [*spatial_slice, temporal_slice, np.arange(data.shape[-1])] + ) + + @classmethod + def _unnormalize(cls, data, val_data, means, stds): + """Remove normalization from stored means and stds + + Parameters + ---------- + data : np.ndarray + Array of training data. + (spatial_1, spatial_2, temporal, n_features) + val_data : np.ndarray + Array of validation data. + (spatial_1, spatial_2, temporal, n_features) + means : np.ndarray + dimensions (features) + array of means for all features with same ordering as data features + stds : np.ndarray + dimensions (features) + array of means for all features with same ordering as data features + """ + val_data = (val_data * stds) + means + data = (data * stds) + means + return data, val_data + + def _normalize_data(self, data, val_data, feature_index, mean, std): + """Normalize data with initialized mean and standard deviation for a + specific feature + + Parameters + ---------- + data : np.ndarray + Array of training data. 
+ (spatial_1, spatial_2, temporal, n_features) + val_data : np.ndarray + Array of validation data. + (spatial_1, spatial_2, temporal, n_features) + feature_index : int + index of feature to be normalized + mean : float32 + specified mean of associated feature + std : float32 + specificed standard deviation for associated feature + """ + + if val_data is not None: + val_data[..., feature_index] -= mean + data[..., feature_index] -= mean + + if std > 0: + if val_data is not None: + val_data[..., feature_index] /= std + data[..., feature_index] /= std + else: + msg = ( + f'Standard Deviation is zero for feature #{feature_index + 1}' + ) + logger.warning(msg) + warnings.warn(msg) + + def _normalize(self, data, val_data, means, stds, max_workers=None): + """Normalize all data features + + Parameters + ---------- + data : np.ndarray + Array of training data. + (spatial_1, spatial_2, temporal, n_features) + val_data : np.ndarray + Array of validation data. + (spatial_1, spatial_2, temporal, n_features) + means : np.ndarray + dimensions (features) + array of means for all features with same ordering as data features + stds : np.ndarray + dimensions (features) + array of means for all features with same ordering as data features + max_workers : int | None + Number of workers to use in thread pool for nomalization. + """ + logger.info(f'Normalizing {data.shape[-1]} features.') + if max_workers == 1: + for i in range(data.shape[-1]): + self._normalize_data(data, val_data, i, means[i], stds[i]) + else: + self.parallel_normalization( + data, val_data, means, stds, max_workers=max_workers + ) + + def parallel_normalization( + self, data, val_data, means, stds, max_workers=None + ): + """Run normalization of features in parallel + + Parameters + ---------- + data : np.ndarray + Array of training data. + (spatial_1, spatial_2, temporal, n_features) + val_data : np.ndarray + Array of validation data. + (spatial_1, spatial_2, temporal, n_features) + means : np.ndarray + dimensions (features) + array of means for all features with same ordering as data features + stds : np.ndarray + dimensions (features) + array of means for all features with same ordering as data features + max_workers : int | None + Max number of workers to use for normalizing features + """ + + with ThreadPoolExecutor(max_workers=max_workers) as exe: + futures = {} + now = dt.now() + for i in range(data.shape[-1]): + future = exe.submit( + self._normalize_data, data, val_data, i, means[i], stds[i] + ) + futures[future] = i + + logger.info( + f'Started normalizing {data.shape[-1]} features ' + f'in {dt.now() - now}.' + ) + + for i, future in enumerate(as_completed(futures)): + try: + future.result() + except Exception as e: + msg = ( + 'Error while normalizing future number ' + f'{futures[future]}.' + ) + logger.exception(msg) + raise RuntimeError(msg) from e + logger.debug( + f'{i+1} out of {data.shape[-1]} features ' 'normalized.' + ) diff --git a/sup3r/preprocessing/data_handling/nc_data_handling.py b/sup3r/preprocessing/data_handling/nc_data_handling.py new file mode 100644 index 000000000..d655f4c5f --- /dev/null +++ b/sup3r/preprocessing/data_handling/nc_data_handling.py @@ -0,0 +1,816 @@ +"""Data handling for netcdf files. 
+@author: bbenton +""" + +import logging +import os +import warnings +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime as dt +from typing import ClassVar + +import numpy as np +import pandas as pd +import xarray as xr +from rex import Resource +from scipy.ndimage.filters import gaussian_filter +from scipy.spatial import KDTree +from scipy.stats import mode + +from sup3r.preprocessing.data_handling.base import DataHandler, DataHandlerDC +from sup3r.preprocessing.feature_handling import (BVFreqMon, BVFreqSquaredNC, + ClearSkyRatioCC, Feature, + InverseMonNC, LatLonNC, + PotentialTempNC, PressureNC, + Rews, Shear, Tas, TasMax, + TasMin, TempNC, TempNCforCC, + UWind, VWind, + WinddirectionNC, WindspeedNC) +from sup3r.utilities.interpolation import Interpolator +from sup3r.utilities.utilities import (estimate_max_workers, get_time_dim_name, + np_to_pd_times) + +np.random.seed(42) + +logger = logging.getLogger(__name__) + + +class DataHandlerNC(DataHandler): + """Data Handler for NETCDF data""" + + CHUNKS: ClassVar[dict] = { + 'XTIME': 100, + 'XLAT': 150, + 'XLON': 150, + 'south_north': 150, + 'west_east': 150, + 'Time': 100, + } + """CHUNKS sets the chunk sizes to extract from the data in each dimension. + Chunk sizes that approximately match the data volume being extracted + typically result in the most efficient IO.""" + + def __init__(self, *args, xr_chunks=None, **kwargs): + """Initialize NETCDF data handler. + + Parameters + ---------- + *args : list + Same ordered required arguments as DataHandler parent class. + xr_chunks : int | "auto" | tuple | dict | None + kwarg that goes to xr.DataArray.chunk(chunks=xr_chunks). Chunk + sizes that approximately match the data volume being extracted + typically result in the most efficient IO. If not provided, this + defaults to the class CHUNKS attribute. + **kwargs : list + Same optional keyword arguments as DataHandler parent class. + """ + if xr_chunks is not None: + self.CHUNKS = xr_chunks + + super().__init__(*args, **kwargs) + + @property + def extract_workers(self): + """Get upper bound for extract workers based on memory limits. Used to + extract data from source dataset""" + # This large multiplier is due to the height interpolation allocating + # multiple arrays with up to 60 vertical levels + proc_mem = 6 * 64 * self.grid_mem * len(self.time_index) + proc_mem /= len(self.time_chunks) + n_procs = len(self.time_chunks) * len(self.extract_features) + n_procs = int(np.ceil(n_procs)) + extract_workers = estimate_max_workers( + self._extract_workers, proc_mem, n_procs + ) + return extract_workers + + @classmethod + def source_handler(cls, file_paths, **kwargs): + """Xarray data handler + + Note that xarray appears to treat open file handlers as singletons + within a threadpool, so it's okay to open this source_handler without a + context handler or a .close() statement. + + Parameters + ---------- + file_paths : str | list + paths to data files + kwargs : dict + kwargs passed to source handler for data extraction. e.g.
This + could be {'parallel': True, + 'chunks': {'south_north': 120, 'west_east': 120}} + which then gets passed to xr.open_mfdataset(file, **kwargs) + + Returns + ------- + data : xarray.Dataset + """ + time_key = get_time_dim_name(file_paths[0]) + default_kws = { + 'combine': 'nested', + 'concat_dim': time_key, + 'chunks': cls.CHUNKS, + } + default_kws.update(kwargs) + return xr.open_mfdataset(file_paths, **default_kws) + + @classmethod + def get_file_times(cls, file_paths, **kwargs): + """Get time index from data files + + Parameters + ---------- + file_paths : list + path to data file + kwargs : dict + kwargs passed to source handler for data extraction. e.g. This + could be {'parallel': True, + 'chunks': {'south_north': 120, 'west_east': 120}} + which then gets passed to xr.open_mfdataset(file, **kwargs) + + Returns + ------- + time_index : pd.Datetimeindex + List of times as a Datetimeindex + """ + handle = cls.source_handler(file_paths, **kwargs) + + if hasattr(handle, 'Times'): + time_index = np_to_pd_times(handle.Times.values) + elif hasattr(handle, 'indexes') and 'time' in handle.indexes: + time_index = handle.indexes['time'] + if not isinstance(time_index, pd.DatetimeIndex): + time_index = time_index.to_datetimeindex() + elif hasattr(handle, 'times'): + time_index = np_to_pd_times(handle.times.values) + else: + msg = ( + f'Could not get time_index for {file_paths}. ' + 'Assuming time independence.' + ) + time_index = None + logger.warning(msg) + warnings.warn(msg) + + return time_index + + @classmethod + def get_time_index(cls, file_paths, max_workers=None, **kwargs): + """Get time index from data files + + Parameters + ---------- + file_paths : list + path to data file + max_workers : int | None + Max number of workers to use for parallel time index building + kwargs : dict + kwargs passed to source handler for data extraction. e.g. This + could be {'parallel': True, + 'chunks': {'south_north': 120, 'west_east': 120}} + which then gets passed to xr.open_mfdataset(file, **kwargs) + + Returns + ------- + time_index : pd.Datetimeindex + List of times as a Datetimeindex + """ + max_workers = ( + len(file_paths) + if max_workers is None + else np.min((max_workers, len(file_paths))) + ) + if max_workers == 1: + return cls.get_file_times(file_paths, **kwargs) + ti = {} + with ThreadPoolExecutor(max_workers=max_workers) as exe: + futures = {} + now = dt.now() + for i, f in enumerate(file_paths): + future = exe.submit(cls.get_file_times, [f], **kwargs) + futures[future] = {'idx': i, 'file': f} + + logger.info( + f'Started building time index from {len(file_paths)} ' + f'files in {dt.now() - now}.' + ) + + for i, future in enumerate(as_completed(futures)): + try: + val = future.result() + if val is not None: + ti[futures[future]['idx']] = list(val) + except Exception as e: + msg = ( + 'Error while getting time index from file ' + f'{futures[future]["file"]}.' 
+ ) + logger.exception(msg) + raise RuntimeError(msg) from e + logger.debug(f'Stored {i+1} out of {len(futures)} file times') + times = np.concatenate(list(ti.values())) + return pd.DatetimeIndex(sorted(set(times))) + + @classmethod + def feature_registry(cls): + """Registry of methods for computing features + + Returns + ------- + dict + Method registry + """ + registry = { + 'BVF2_(.*)m': BVFreqSquaredNC, + 'BVF_MO_(.*)m': BVFreqMon, + 'RMOL': InverseMonNC, + 'U_(.*)': UWind, + 'V_(.*)': VWind, + 'Windspeed_(.*)m': WindspeedNC, + 'Winddirection_(.*)m': WinddirectionNC, + 'lat_lon': LatLonNC, + 'Shear_(.*)m': Shear, + 'REWS_(.*)m': Rews, + 'Temperature_(.*)m': TempNC, + 'Pressure_(.*)m': PressureNC, + 'PotentialTemp_(.*)m': PotentialTempNC, + 'PT_(.*)m': PotentialTempNC, + 'topography': ['HGT', 'orog'], + } + return registry + + @classmethod + def extract_feature( + cls, + file_paths, + raster_index, + feature, + time_slice=slice(None), + **kwargs, + ): + """Extract single feature from data source. The requested feature + can match exactly to one found in the source data or can have a + matching prefix with a suffix specifying the height or pressure level + to interpolate to. e.g. feature=U_100m -> interpolate exact match U to + 100 meters. + + Parameters + ---------- + file_paths : list + path to data file + raster_index : ndarray + Raster index array + feature : str + Feature to extract from data + time_slice : slice + slice of time to extract + kwargs : dict + kwargs passed to source handler for data extraction. e.g. This + could be {'parallel': True, + 'chunks': {'south_north': 120, 'west_east': 120}} + which then gets passed to xr.open_mfdataset(file, **kwargs) + + Returns + ------- + ndarray + Data array for extracted feature + (spatial_1, spatial_2, temporal) + """ + logger.debug( + f'Extracting {feature} with time_slice={time_slice}, ' + f'raster_index={raster_index}, kwargs={kwargs}.' + ) + handle = cls.source_handler(file_paths, **kwargs) + f_info = Feature(feature, handle) + interp_height = f_info.height + interp_pressure = f_info.pressure + basename = f_info.basename + + if feature in handle or feature.lower() in handle: + feat_key = feature if feature in handle else feature.lower() + fdata = cls.direct_extract( + handle, feat_key, raster_index, time_slice + ) + + elif basename in handle or basename.lower() in handle: + feat_key = basename if basename in handle else basename.lower() + if interp_height is not None: + fdata = Interpolator.interp_var_to_height( + handle, + feat_key, + raster_index, + np.float32(interp_height), + time_slice, + ) + elif interp_pressure is not None: + fdata = Interpolator.interp_var_to_pressure( + handle, + feat_key, + raster_index, + np.float32(interp_pressure), + time_slice, + ) + + else: + msg = f'{feature} cannot be extracted from source data.' 
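The registry keys above are regular expressions that are matched against requested feature names such as 'U_100m'; the snippet below is a simplified standalone sketch of that idea (hypothetical registry values, not the actual sup3r lookup code).

import re

# hypothetical mapping of patterns to handler names
registry = {
    'U_(.*)': 'UWind',
    'Temperature_(.*)m': 'TempNC',
    'topography': 'HGT',
}

def resolve(feature):
    """Return (handler, suffix) for the first registry pattern that matches."""
    for pattern, handler in registry.items():
        match = re.match(pattern, feature, flags=re.IGNORECASE)
        if match:
            suffix = match.groups()[0] if match.groups() else None
            return handler, suffix
    return None, None

print(resolve('U_100m'))           # ('UWind', '100m')
print(resolve('temperature_40m'))  # ('TempNC', '40')
print(resolve('topography'))       # ('HGT', None)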
+ logger.exception(msg) + raise ValueError(msg) + + fdata = np.transpose(fdata, (1, 2, 0)) + return fdata.astype(np.float32) + + @classmethod + def direct_extract(cls, handle, feature, raster_index, time_slice): + """Extract requested feature directly from source data, rather than + interpolating to a requested height or pressure level + + Parameters + ---------- + handle : xarray + netcdf data object + feature : str + Name of feature to extract directly from source handler + raster_index : list + List of slices for raster index of spatial domain + time_slice : slice + slice of time to extract + + Returns + ------- + fdata : ndarray + Data array for requested feature + """ + # Sometimes xarray returns fields with (Times, time, lats, lons) + # with a single entry in the 'time' dimension so we include this [0] + if len(handle[feature].dims) == 4: + idx = tuple([time_slice, 0, *raster_index]) + elif len(handle[feature].dims) == 3: + idx = tuple([time_slice, *raster_index]) + else: + idx = tuple(raster_index) + fdata = np.array(handle[feature][idx], dtype=np.float32) + if len(fdata.shape) == 2: + fdata = np.expand_dims(fdata, axis=0) + return fdata + + @classmethod + def get_full_domain(cls, file_paths): + """Get full shape and min available lat lon. To simplify processing + of full domain without needing to specify target and shape. + + Parameters + ---------- + file_paths : list + List of data file paths + + Returns + ------- + target : tuple + (lat, lon) for lower left corner + lat_lon : ndarray + Raw lat/lon array for entire domain + """ + return cls.get_lat_lon(file_paths, [slice(None), slice(None)]) + + @staticmethod + def get_closest_lat_lon(lat_lon, target): + """Get closest indices to target lat lon to use for lower left corner + of raster index + + Parameters + ---------- + lat_lon : ndarray + Array of lat/lon + (spatial_1, spatial_2, 2) + Last dimension in order of (lat, lon) + target : tuple + (lat, lon) for lower left corner + + Returns + ------- + row : int + row index for closest lat/lon to target lat/lon + col : int + col index for closest lat/lon to target lat/lon + """ + # shape of ll2 is (n, 2) where axis=1 is (lat, lon) + ll2 = np.vstack( + (lat_lon[..., 0].flatten(), lat_lon[..., 1].flatten()) + ).T + tree = KDTree(ll2) + _, i = tree.query(np.array(target)) + row, col = np.where( + (lat_lon[..., 0] == ll2[i, 0]) & (lat_lon[..., 1] == ll2[i, 1]) + ) + row = row[0] + col = col[0] + return row, col + + @classmethod + def compute_raster_index(cls, file_paths, target, grid_shape): + """Get raster index for a given target and shape + + Parameters + ---------- + file_paths : list + List of input data file paths + target : tuple + Target coordinate for lower left corner of extracted data + grid_shape : tuple + Shape out extracted data + + Returns + ------- + list + List of slices corresponding to extracted data region + """ + lat_lon = cls.get_lat_lon( + file_paths[:1], [slice(None), slice(None)], invert_lat=False + ) + cls._check_grid_extent(target, grid_shape, lat_lon) + + row, col = cls.get_closest_lat_lon(lat_lon, target) + + closest = tuple(lat_lon[row, col]) + logger.debug(f'Found closest coordinate {closest} to target={target}') + if np.hypot(closest[0] - target[0], closest[1] - target[1]) > 1: + msg = 'Closest coordinate to target is more than 1 degree away' + logger.warning(msg) + warnings.warn(msg) + + if cls.lats_are_descending(lat_lon): + row_end = row + 1 + row_start = row_end - grid_shape[0] + else: + row_end = row + grid_shape[0] + row_start = row + raster_index = [ 
+ slice(row_start, row_end), + slice(col, col + grid_shape[1]), + ] + cls._validate_raster_shape(target, grid_shape, lat_lon, raster_index) + return raster_index + + @classmethod + def _check_grid_extent(cls, target, grid_shape, lat_lon): + """Make sure the requested target coordinate lies within the available + lat/lon grid. + + Parameters + ---------- + target : tuple + Target coordinate for lower left corner of extracted data + grid_shape : tuple + Shape out extracted data + lat_lon : ndarray + Array of lat/lon coordinates for entire available grid. Used to + check whether computed raster only includes coordinates within this + grid. + """ + min_lat = np.min(lat_lon[..., 0]) + min_lon = np.min(lat_lon[..., 1]) + max_lat = np.max(lat_lon[..., 0]) + max_lon = np.max(lat_lon[..., 1]) + logger.debug( + 'Calculating raster index from WRF file ' + f'for shape {grid_shape} and target {target}' + ) + logger.debug( + f'lat/lon (min, max): {min_lat}/{min_lon}, ' f'{max_lat}/{max_lon}' + ) + msg = ( + f'target {target} out of bounds with min lat/lon ' + f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}' + ) + assert ( + min_lat <= target[0] <= max_lat and min_lon <= target[1] <= max_lon + ), msg + + @classmethod + def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): + """Make sure the computed raster_index only includes coordinates within + the available grid + + Parameters + ---------- + target : tuple + Target coordinate for lower left corner of extracted data + grid_shape : tuple + Shape out extracted data + lat_lon : ndarray + Array of lat/lon coordinates for entire available grid. Used to + check whether computed raster only includes coordinates within this + grid. + raster_index : list + List of slices selecting region from entire available grid. + """ + if ( + raster_index[0].stop > lat_lon.shape[0] + or raster_index[1].stop > lat_lon.shape[1] + or raster_index[0].start < 0 + or raster_index[1].start < 0 + ): + msg = ( + f'Invalid target {target}, shape {grid_shape}, and raster ' + f'{raster_index} for data domain of size ' + f'{lat_lon.shape[:-1]} with lower left corner ' + f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' + f' and upper right corner ({np.max(lat_lon[..., 0])}, ' + f'{np.max(lat_lon[..., 1])}).' + ) + raise ValueError(msg) + + def get_raster_index(self): + """Get raster index for file data. Here we assume the list of paths in + file_paths all have data with the same spatial domain. We use the first + file in the list to compute the raster. 
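To make the raster-index construction above concrete, the toy example below finds the grid point nearest a requested lower-left target with a KDTree and builds row/column slices of the requested shape; it assumes a regular grid with ascending latitudes and is only a sketch of the logic.

import numpy as np
from scipy.spatial import KDTree

# toy 100x100 regular grid of (lat, lon) pairs
lats, lons = np.meshgrid(np.linspace(30, 40, 100),
                         np.linspace(-110, -100, 100),
                         indexing='ij')
lat_lon = np.dstack((lats, lons))  # (100, 100, 2)

target = (33.3, -106.7)   # requested lower left corner
grid_shape = (20, 20)     # requested raster shape

# nearest grid point to the target coordinate
tree = KDTree(lat_lon.reshape(-1, 2))
_, flat_idx = tree.query(np.array(target))
row, col = np.unravel_index(flat_idx, lat_lon.shape[:2])

# ascending latitudes, so rows grow upward from the target row
raster_index = [slice(row, row + grid_shape[0]),
                slice(col, col + grid_shape[1])]
subset = lat_lon[raster_index[0], raster_index[1]]
print(subset.shape)  # (20, 20, 2)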
+ + Returns + ------- + raster_index : np.ndarray + 2D array of grid indices + """ + self.raster_file = ( + self.raster_file + if self.raster_file is None + else self.raster_file.replace('.txt', '.npy') + ) + if self.raster_file is not None and os.path.exists(self.raster_file): + logger.debug( + f'Loading raster index: {self.raster_file} ' + f'for {self.input_file_info}' + ) + raster_index = np.load(self.raster_file, allow_pickle=True) + raster_index = list(raster_index) + else: + check = self.grid_shape is not None and self.target is not None + msg = ( + 'Must provide raster file or shape + target to get ' + 'raster index' + ) + assert check, msg + raster_index = self.compute_raster_index( + self.file_paths, self.target, self.grid_shape + ) + logger.debug( + 'Found raster index with row, col slices: {}'.format( + raster_index + ) + ) + + if self.raster_file is not None: + basedir = os.path.dirname(self.raster_file) + if not os.path.exists(basedir): + os.makedirs(basedir) + logger.debug(f'Saving raster index: {self.raster_file}') + np.save(self.raster_file.replace('.txt', '.npy'), raster_index) + + return raster_index + + +class DataHandlerNCforCC(DataHandlerNC): + """Data Handler for NETCDF climate change data""" + + CHUNKS = {'time': 5, 'lat': 20, 'lon': 20} + """CHUNKS sets the chunk sizes to extract from the data in each dimension. + Chunk sizes that approximately match the data volume being extracted + typically results in the most efficient IO.""" + + def __init__( + self, + *args, + nsrdb_source_fp=None, + nsrdb_agg=1, + nsrdb_smoothing=0, + **kwargs, + ): + """Initialize NETCDF data handler for climate change data. + + Parameters + ---------- + *args : list + Same ordered required arguments as DataHandler parent class. + nsrdb_source_fp : str | None + Optional NSRDB source h5 file to retrieve clearsky_ghi from to + calculate CC clearsky_ratio along with rsds (ghi) from the CC + netcdf file. + nsrdb_agg : int + Optional number of NSRDB source pixels to aggregate clearsky_ghi + from to a single climate change netcdf pixel. This can be used if + the CC.nc data is at a much coarser resolution than the source + nsrdb data. + nsrdb_smoothing : float + Optional gaussian filter smoothing factor to smooth out + clearsky_ghi from high-resolution nsrdb source data. This is + typically done because spatially aggregated nsrdb data is still + usually rougher than CC irradiance data. + **kwargs : list + Same optional keyword arguments as DataHandler parent class. 
+ """ + self._nsrdb_source_fp = nsrdb_source_fp + self._nsrdb_agg = nsrdb_agg + self._nsrdb_smoothing = nsrdb_smoothing + super().__init__(*args, **kwargs) + + @classmethod + def feature_registry(cls): + """Registry of methods for computing features or extracting renamed + features + + Returns + ------- + dict + Method registry + """ + registry = { + 'U_(.*)': 'ua_(.*)', + 'V_(.*)': 'va_(.*)', + 'Windspeed_(.*)m': WindspeedNC, + 'Winddirection_(.*)m': WinddirectionNC, + 'topography': 'orog', + 'relativehumidity_2m': 'hurs', + 'relativehumidity_min_2m': 'hursmin', + 'relativehumidity_max_2m': 'hursmax', + 'clearsky_ratio': ClearSkyRatioCC, + 'lat_lon': LatLonNC, + 'Pressure_(.*)': 'plev_(.*)', + 'Temperature_(.*)': TempNCforCC, + 'temperature_2m': Tas, + 'temperature_max_2m': TasMax, + 'temperature_min_2m': TasMin, + } + return registry + + @classmethod + def source_handler(cls, file_paths, **kwargs): + """Xarray data handler + + Note that xarray appears to treat open file handlers as singletons + within a threadpool, so its okay to open this source_handler without a + context handler or a .close() statement. + + Parameters + ---------- + file_paths : str | list + paths to data files + kwargs : dict + kwargs passed to source handler for data extraction. e.g. This + could be {'parallel': True, + 'chunks': {'south_north': 120, 'west_east': 120}} + which then gets passed to xr.open_mfdataset(file, **kwargs) + + Returns + ------- + data : xarray.Dataset + """ + default_kws = {'chunks': cls.CHUNKS} + default_kws.update(kwargs) + return xr.open_mfdataset(file_paths, **default_kws) + + def run_data_extraction(self): + """Run the raw dataset extraction process from disk to raw + un-manipulated datasets. + + Includes a special method to extract clearsky_ghi from a exogenous + NSRDB source h5 file (required to compute clearsky_ratio). + """ + get_clearsky = False + if 'clearsky_ghi' in self.raw_features: + get_clearsky = True + self._raw_features.remove('clearsky_ghi') + + super().run_data_extraction() + + if get_clearsky: + cs_ghi = self.get_clearsky_ghi() + + # clearsky ghi is extracted at the proper starting time index so + # the time chunks should start at 0 + tc0 = self.time_chunks[0].start + cs_ghi_time_chunks = [ + slice(tc.start - tc0, tc.stop - tc0, tc.step) + for tc in self.time_chunks + ] + for it, tslice in enumerate(cs_ghi_time_chunks): + self._raw_data[it]['clearsky_ghi'] = cs_ghi[..., tslice] + + self._raw_features.append('clearsky_ghi') + + def get_clearsky_ghi(self): + """Get clearsky ghi from an exogenous NSRDB source h5 file at the + target CC meta data and time index. + + Returns + ------- + cs_ghi : np.ndarray + Clearsky ghi (W/m2) from the nsrdb_source_fp h5 source file. Data + shape is (lat, lon, time) where time is daily average values. 
+ """ + + msg = ( + 'Need nsrdb_source_fp input arg as a valid filepath to ' + 'retrieve clearsky_ghi (maybe for clearsky_ratio) but ' + 'received: {}'.format(self._nsrdb_source_fp) + ) + assert self._nsrdb_source_fp is not None, msg + assert os.path.exists(self._nsrdb_source_fp), msg + + msg = ( + 'Can only handle source CC data in hourly frequency but ' + 'received daily frequency of {}hrs (should be 24) ' + 'with raw time index: {}'.format( + self.time_freq_hours, self.raw_time_index + ) + ) + assert self.time_freq_hours == 24.0, msg + + msg = ( + 'Can only handle source CC data with temporal_slice.step == 1 ' + 'but received: {}'.format(self.temporal_slice.step) + ) + assert (self.temporal_slice.step is None) | ( + self.temporal_slice.step == 1 + ), msg + + with Resource(self._nsrdb_source_fp) as res: + ti_nsrdb = res.time_index + meta_nsrdb = res.meta + + ti_deltas = ti_nsrdb - np.roll(ti_nsrdb, 1) + ti_deltas_hours = pd.Series(ti_deltas).dt.total_seconds()[1:-1] / 3600 + time_freq = float(mode(ti_deltas_hours).mode) + t_start = self.temporal_slice.start or 0 + t_end_target = self.temporal_slice.stop or len(self.raw_time_index) + t_start = int(t_start * 24 * (1 / time_freq)) + t_end = int(t_end_target * 24 * (1 / time_freq)) + t_end = np.minimum(t_end, len(ti_nsrdb)) + t_slice = slice(t_start, t_end) + + # pylint: disable=E1136 + lat = self.lat_lon[:, :, 0].flatten() + lon = self.lat_lon[:, :, 1].flatten() + cc_meta = np.vstack((lat, lon)).T + + tree = KDTree(meta_nsrdb[['latitude', 'longitude']]) + _, i = tree.query(cc_meta, k=self._nsrdb_agg) + if len(i.shape) == 1: + i = np.expand_dims(i, axis=1) + + logger.info( + 'Extracting clearsky_ghi data from "{}" with time slice ' + '{} and {} locations with agg factor {}.'.format( + os.path.basename(self._nsrdb_source_fp), + t_slice, + i.shape[0], + i.shape[1], + ) + ) + + cs_shape = i.shape + with Resource(self._nsrdb_source_fp) as res: + cs_ghi = res['clearsky_ghi', t_slice, i.flatten()] + + cs_ghi = cs_ghi.reshape((len(cs_ghi), *cs_shape)) + cs_ghi = cs_ghi.mean(axis=-1) + + windows = np.array_split( + np.arange(len(cs_ghi)), len(cs_ghi) // (24 // time_freq) + ) + cs_ghi = [cs_ghi[window].mean(axis=0) for window in windows] + cs_ghi = np.vstack(cs_ghi) + cs_ghi = cs_ghi.reshape((len(cs_ghi), *tuple(self.grid_shape))) + cs_ghi = np.transpose(cs_ghi, axes=(1, 2, 0)) + + if self.invert_lat: + cs_ghi = cs_ghi[::-1] + + logger.info( + 'Smoothing nsrdb clearsky ghi with a factor of {}'.format( + self._nsrdb_smoothing + ) + ) + for iday in range(cs_ghi.shape[-1]): + cs_ghi[..., iday] = gaussian_filter( + cs_ghi[..., iday], self._nsrdb_smoothing, mode='nearest' + ) + + if cs_ghi.shape[-1] < t_end_target: + n = int(np.ceil(t_end_target / cs_ghi.shape[-1])) + cs_ghi = np.repeat(cs_ghi, n, axis=2) + cs_ghi = cs_ghi[..., :t_end_target] + + logger.info( + 'Reshaped clearsky_ghi data to final shape {} to ' + 'correspond with CC daily average data over source ' + 'temporal_slice {} with (lat, lon) grid shape of {}'.format( + cs_ghi.shape, self.temporal_slice, self.grid_shape + ) + ) + + return cs_ghi + + +class DataHandlerDCforNC(DataHandlerNC, DataHandlerDC): + """Data centric data handler for NETCDF files""" diff --git a/sup3r/preprocessing/dual_batch_handling.py b/sup3r/preprocessing/dual_batch_handling.py new file mode 100644 index 000000000..a6a1d34d4 --- /dev/null +++ b/sup3r/preprocessing/dual_batch_handling.py @@ -0,0 +1,285 @@ +"""Batch handling classes for dual data handlers""" +import logging + +import numpy as np + +from 
sup3r.preprocessing.batch_handling import ( + Batch, + BatchHandler, + ValidationData, +) +from sup3r.utilities.utilities import uniform_box_sampler, uniform_time_sampler + +logger = logging.getLogger(__name__) + + +class DualValidationData(ValidationData): + """Iterator for validation data for training with dual data handler""" + + # Classes to use for handling an individual batch obj. + BATCH_CLASS = Batch + + def _get_val_indices(self): + """List of dicts to index each validation data observation across all + handlers + + Returns + ------- + val_indices : list[dict] + List of dicts with handler_index and tuple_index. The tuple index + is used to get validation data observation with + data[tuple_index] + """ + + val_indices = [] + for i, h in enumerate(self.data_handlers): + if h.hr_val_data is not None: + for _ in range(h.hr_val_data.shape[2]): + spatial_slice = uniform_box_sampler( + h.lr_val_data, self.lr_sample_shape[:2]) + temporal_slice = uniform_time_sampler( + h.lr_val_data, self.lr_sample_shape[2]) + lr_index = tuple([ + *spatial_slice, temporal_slice, + np.arange(h.lr_val_data.shape[-1]) + ]) + hr_index = [] + for s in lr_index[:2]: + hr_index.append( + slice( + s.start * self.s_enhance, + s.stop * self.s_enhance, + )) + for s in lr_index[2:-1]: + hr_index.append( + slice( + s.start * self.t_enhance, + s.stop * self.t_enhance, + )) + hr_index.append(lr_index[-1]) + hr_index = tuple(hr_index) + val_indices.append({ + 'handler_index': i, + 'hr_index': hr_index, + 'lr_index': lr_index, + }) + return val_indices + + @property + def shape(self): + """Shape of full validation dataset across all handlers + + Returns + ------- + shape : tuple + (spatial_1, spatial_2, temporal, features) + With temporal extent equal to the sum across all data handlers time + dimension + """ + time_steps = 0 + for h in self.data_handlers: + time_steps += h.hr_val_data.shape[2] + return ( + self.data_handlers[0].hr_val_data.shape[0], + self.data_handlers[0].hr_val_data.shape[1], + time_steps, + self.data_handlers[0].hr_val_data.shape[3], + ) + + @property + def hr_sample_shape(self): + """Get sample shape for high_res data""" + return self.data_handlers[0].hr_dh.sample_shape + + @property + def lr_sample_shape(self): + """Get sample shape for low_res data""" + return self.data_handlers[0].lr_dh.sample_shape + + def __next__(self): + """Get validation data batch + + Returns + ------- + batch : Batch + validation data batch with low and high res data each with + n_observations = batch_size + """ + self.current_batch_indices = [] + if self._remaining_observations > 0: + if self._remaining_observations > self.batch_size: + n_obs = self.batch_size + else: + n_obs = self._remaining_observations + + high_res = np.zeros( + ( + n_obs, + self.hr_sample_shape[0], + self.hr_sample_shape[1], + self.hr_sample_shape[2], + self.data_handlers[0].shape[-1], + ), + dtype=np.float32, + ) + low_res = np.zeros( + ( + n_obs, + self.lr_sample_shape[0], + self.lr_sample_shape[1], + self.lr_sample_shape[2], + self.data_handlers[0].shape[-1], + ), + dtype=np.float32, + ) + for i in range(high_res.shape[0]): + val_index = self.val_indices[self._i + i] + high_res[i, ...] = self.data_handlers[val_index[ + 'handler_index']].hr_val_data[val_index['hr_index']] + low_res[i, ...] = self.data_handlers[val_index[ + 'handler_index']].lr_val_data[val_index['lr_index']] + self._remaining_observations -= 1 + self.current_batch_indices.append(val_index) + + # This checks if there is only a single timestep. 
If so this means + # we are using a spatial batch handler which uses 4D batches. + if self.sample_shape[2] == 1: + high_res = high_res[..., 0, :] + low_res = low_res[..., 0, :] + + high_res = self.BATCH_CLASS.reduce_features( + high_res, self.output_features_ind) + batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) + self._i += 1 + return batch + else: + raise StopIteration + + +class DualBatchHandler(BatchHandler): + """Batch handling class for dual data handlers""" + + BATCH_CLASS = Batch + VAL_CLASS = DualValidationData + + @property + def hr_sample_shape(self): + """Get sample shape for high_res data""" + return self.data_handlers[0].hr_dh.sample_shape + + @property + def lr_sample_shape(self): + """Get sample shape for low_res data""" + return self.data_handlers[0].lr_dh.sample_shape + + def __iter__(self): + self._i = 0 + return self + + def __next__(self): + """Get the next iterator output. + + Returns + ------- + batch : Batch + Batch object with batch.low_res and batch.high_res attributes + with the appropriate subsampling of interpolated ERA. + """ + self.current_batch_indices = [] + if self._i < self.n_batches: + handler_index = np.random.randint(0, len(self.data_handlers)) + self.current_handler_index = handler_index + handler = self.data_handlers[handler_index] + high_res = np.zeros( + ( + self.batch_size, + self.hr_sample_shape[0], + self.hr_sample_shape[1], + self.hr_sample_shape[2], + self.shape[-1], + ), + dtype=np.float32, + ) + low_res = np.zeros( + ( + self.batch_size, + self.lr_sample_shape[0], + self.lr_sample_shape[1], + self.lr_sample_shape[2], + self.shape[-1], + ), + dtype=np.float32, + ) + + for i in range(self.batch_size): + high_res[i, ...], low_res[i, ...] = handler.get_next() + self.current_batch_indices.append(handler.current_obs_index) + + high_res = self.BATCH_CLASS.reduce_features( + high_res, self.output_features_ind) + batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) + + self._i += 1 + return batch + else: + raise StopIteration + + +class SpatialDualBatchHandler(DualBatchHandler): + """Batch handling class for h5 data as high res (usually WTK) and ERA5 as + low res""" + + BATCH_CLASS = Batch + VAL_CLASS = DualValidationData + + def __iter__(self): + self._i = 0 + return self + + def __next__(self): + """Get the next iterator output. + + Returns + ------- + batch : Batch + Batch object with batch.low_res and batch.high_res attributes + with the appropriate subsampling of interpolated ERA. + """ + self.current_batch_indices = [] + if self._i < self.n_batches: + handler_index = np.random.randint(0, len(self.data_handlers)) + self.current_handler_index = handler_index + handler = self.data_handlers[handler_index] + high_res = np.zeros( + ( + self.batch_size, + self.hr_sample_shape[0], + self.hr_sample_shape[1], + self.shape[-1], + ), + dtype=np.float32, + ) + low_res = np.zeros( + ( + self.batch_size, + self.lr_sample_shape[0], + self.lr_sample_shape[1], + self.shape[-1], + ), + dtype=np.float32, + ) + + for i in range(self.batch_size): + hr, lr = handler.get_next() + high_res[i, ...] = hr[..., 0, :] + low_res[i, ...] 
= lr[..., 0, :] + self.current_batch_indices.append(handler.current_obs_index) + + high_res = self.BATCH_CLASS.reduce_features( + high_res, self.output_features_ind) + batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) + + self._i += 1 + return batch + else: + raise StopIteration diff --git a/sup3r/preprocessing/feature_handling.py b/sup3r/preprocessing/feature_handling.py index fc6c620b9..df784aa7a 100644 --- a/sup3r/preprocessing/feature_handling.py +++ b/sup3r/preprocessing/feature_handling.py @@ -1615,6 +1615,44 @@ def _pattern_lookup(cls, feature): break return out + @classmethod + def _lookup(cls, out, feature, handle_features=None): + """Lookup feature in feature registry + + Parameters + ---------- + out : None + Candidate registry method for feature + feature : str + Feature to lookup in registry + handle_features : list + List of feature names (datasets) available in the source file. If + feature is found explicitly in this list, height/pressure suffixes + will not be appended to the output. + + Returns + ------- + method | None + Feature registry method corresponding to feature + """ + if isinstance(out, list): + for v in out: + if v in handle_features: + return lambda x: [v] + + if out in handle_features: + return lambda x: [out] + + height = Feature.get_height(feature) + if height is not None: + out = out.split('(.*)')[0] + f'{height}m' + + pressure = Feature.get_pressure(feature) + if pressure is not None: + out = out.split('(.*)')[0] + f'{pressure}pa' + + return lambda x: [out] + @classmethod def lookup(cls, feature, attr_name, handle_features=None): """Lookup feature in feature registry @@ -1635,7 +1673,6 @@ def lookup(cls, feature, attr_name, handle_features=None): method | None Feature registry method corresponding to feature """ - handle_features = handle_features or [] out = cls._exact_lookup(feature) @@ -1645,23 +1682,12 @@ def lookup(cls, feature, attr_name, handle_features=None): if out is None: return None - if not isinstance(out, str): + if not isinstance(out, (str, list)): return getattr(out, attr_name, None) elif attr_name == 'inputs': - if out in handle_features: - return lambda x: [out] - - height = Feature.get_height(feature) - if height is not None: - out = out.split('(.*)')[0] + f'{height}m' - - pressure = Feature.get_pressure(feature) - if pressure is not None: - out = out.split('(.*)')[0] + f'{pressure}pa' - - return lambda x: [out] + return cls._lookup(out, feature, handle_features) @classmethod def get_inputs_recursive(cls, feature, handle_features): @@ -1702,7 +1728,6 @@ def get_inputs_recursive(cls, feature, handle_features): for r in cls.get_inputs_recursive(f, handle_features): if r not in raw_features: raw_features.append(r) - return raw_features @classmethod diff --git a/sup3r/qa/qa.py b/sup3r/qa/qa.py index 8622520e5..c715f446c 100644 --- a/sup3r/qa/qa.py +++ b/sup3r/qa/qa.py @@ -32,24 +32,30 @@ class Sup3rQa: into a 2D raster dataset (e.g. no sparsifying of the meta data). 
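The string-valued registry entries handled by _lookup above are resolved by swapping the '(.*)' wildcard for the height or pressure suffix of the requested feature; the standalone sketch below approximates that substitution (simplified parsing, not the sup3r Feature helpers).

import re

def resolve_input_name(registry_value, feature):
    """Map a requested feature onto the raw dataset name it is computed from."""
    height = re.search(r'_(\d+)m$', feature)
    pressure = re.search(r'_(\d+)pa$', feature)
    if height is not None:
        return registry_value.split('(.*)')[0] + f'{height.group(1)}m'
    if pressure is not None:
        return registry_value.split('(.*)')[0] + f'{pressure.group(1)}pa'
    return registry_value

# e.g. 'U_100m' is computed from the raw 'ua_100m' variable in CC-style files
print(resolve_input_name('ua_(.*)', 'U_100m'))            # ua_100m
print(resolve_input_name('plev_(.*)', 'Pressure_850pa'))  # plev_850pa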
""" - def __init__(self, source_file_paths, out_file_path, s_enhance, t_enhance, - temporal_coarsening_method, - features=None, - source_features=None, - output_names=None, - temporal_slice=slice(None), - target=None, - shape=None, - raster_file=None, - qa_fp=None, - bias_correct_method=None, - bias_correct_kwargs=None, - save_sources=True, - time_chunk_size=None, - cache_pattern=None, - overwrite_cache=False, - input_handler=None, - worker_kwargs=None): + def __init__( + self, + source_file_paths, + out_file_path, + s_enhance, + t_enhance, + temporal_coarsening_method, + features=None, + source_features=None, + output_names=None, + temporal_slice=slice(None), + target=None, + shape=None, + raster_file=None, + qa_fp=None, + bias_correct_method=None, + bias_correct_kwargs=None, + save_sources=True, + time_chunk_size=None, + cache_pattern=None, + overwrite_cache=False, + input_handler=None, + worker_kwargs=None, + ): """Parameters ---------- source_file_paths : list | str @@ -175,14 +181,19 @@ def __init__(self, source_file_paths, out_file_path, s_enhance, t_enhance, self.t_enhance = t_enhance self._t_meth = temporal_coarsening_method self._out_fp = out_file_path - self._features = (features if isinstance(features, (list, tuple)) - else [features]) - self._source_features = (source_features if - isinstance(source_features, (list, tuple)) - else [source_features]) - self._out_names = (output_names if - isinstance(output_names, (list, tuple)) - else [output_names]) + self._features = ( + features if isinstance(features, (list, tuple)) else [features] + ) + self._source_features = ( + source_features + if isinstance(source_features, (list, tuple)) + else [source_features] + ) + self._out_names = ( + output_names + if isinstance(output_names, (list, tuple)) + else [output_names] + ) self.qa_fp = qa_fp self.save_sources = save_sources self.output_handler = self.output_handler_class(self._out_fp) @@ -190,19 +201,22 @@ def __init__(self, source_file_paths, out_file_path, s_enhance, t_enhance, self.bias_correct_method = bias_correct_method self.bias_correct_kwargs = bias_correct_kwargs or {} - HandlerClass = get_input_handler_class(source_file_paths, - input_handler) - self.source_handler = HandlerClass(source_file_paths, - self.source_features_flat, - target=target, - shape=shape, - temporal_slice=temporal_slice, - raster_file=raster_file, - cache_pattern=cache_pattern, - time_chunk_size=time_chunk_size, - overwrite_cache=overwrite_cache, - val_split=0.0, - worker_kwargs=worker_kwargs) + HandlerClass = get_input_handler_class( + source_file_paths, input_handler + ) + self.source_handler = HandlerClass( + source_file_paths, + self.source_features_flat, + target=target, + shape=shape, + temporal_slice=temporal_slice, + raster_file=raster_file, + cache_pattern=cache_pattern, + time_chunk_size=time_chunk_size, + overwrite_cache=overwrite_cache, + val_split=0.0, + worker_kwargs=worker_kwargs, + ) def __enter__(self): return self @@ -225,8 +239,12 @@ def meta(self): pd.DataFrame """ lat_lon = self.source_handler.lat_lon - meta = pd.DataFrame({'latitude': lat_lon[..., 0].flatten(), - 'longitude': lat_lon[..., 1].flatten()}) + meta = pd.DataFrame( + { + 'latitude': lat_lon[..., 0].flatten(), + 'longitude': lat_lon[..., 1].flatten(), + } + ) return meta @property @@ -369,23 +387,32 @@ def bias_correct_source_data(self, data, lat_lon, source_feature): if 'time_index' in signature(method).parameters: feature_kwargs['time_index'] = self.time_index - if ('lr_padded_slice' in signature(method).parameters - and 
'lr_padded_slice' not in feature_kwargs): + if ( + 'lr_padded_slice' in signature(method).parameters + and 'lr_padded_slice' not in feature_kwargs + ): feature_kwargs['lr_padded_slice'] = None - if ('temporal_avg' in signature(method).parameters - and 'temporal_avg' not in feature_kwargs): - msg = ('The kwarg "temporal_avg" was not provided in the bias ' - 'correction kwargs but is present in the bias ' - 'correction function "{}". If this is not set ' - 'appropriately, especially for monthly bias ' - 'correction, it could result in QA results that look ' - 'worse than they actually are.'.format(method)) + if ( + 'temporal_avg' in signature(method).parameters + and 'temporal_avg' not in feature_kwargs + ): + msg = ( + 'The kwarg "temporal_avg" was not provided in the bias ' + 'correction kwargs but is present in the bias ' + 'correction function "{}". If this is not set ' + 'appropriately, especially for monthly bias ' + 'correction, it could result in QA results that look ' + 'worse than they actually are.'.format(method) + ) logger.warning(msg) warn(msg) - logger.debug('Bias correcting source_feature "{}" using ' - 'function: {} with kwargs: {}' - .format(source_feature, method, feature_kwargs)) + logger.debug( + 'Bias correcting source_feature "{}" using ' + 'function: {} with kwargs: {}'.format( + source_feature, method, feature_kwargs + ) + ) data = method(data, lat_lon, **feature_kwargs) @@ -410,9 +437,10 @@ def get_source_dset(self, feature, source_feature): lat_lon = self.source_handler.lat_lon if 'windspeed' in feature and len(source_feature) == 2: u_feat, v_feat = source_feature - logger.info('For sup3r output feature "{}", retrieving u/v ' - 'components "{}" and "{}"' - .format(feature, u_feat, v_feat)) + logger.info( + 'For sup3r output feature "{}", retrieving u/v ' + 'components "{}" and "{}"'.format(feature, u_feat, v_feat) + ) u_idf = self.source_handler.features.index(u_feat) v_idf = self.source_handler.features.index(v_feat) u_true = self.source_handler.data[..., u_idf] @@ -423,8 +451,9 @@ def get_source_dset(self, feature, source_feature): else: idf = self.source_handler.features.index(source_feature) data_true = self.source_handler.data[..., idf] - data_true = self.bias_correct_source_data(data_true, lat_lon, - source_feature) + data_true = self.bias_correct_source_data( + data_true, lat_lon, source_feature + ) return data_true @@ -449,9 +478,11 @@ def get_dset_out(self, name): if self.output_type == 'nc': data = data.values elif self.output_type == 'h5': - shape = (len(self.time_index) * self.t_enhance, - int(self.lr_shape[0] * self.s_enhance), - int(self.lr_shape[1] * self.s_enhance)) + shape = ( + len(self.time_index) * self.t_enhance, + int(self.lr_shape[0] * self.s_enhance), + int(self.lr_shape[1] * self.s_enhance), + ) data = data.reshape(shape) # data always needs to be converted from (t, s1, s2) -> (s1, s2, t) @@ -478,21 +509,28 @@ def coarsen_data(self, idf, feature, data): A spatiotemporally coarsened copy of the input dataset, still with shape (spatial_1, spatial_2, temporal) """ - t_meth = (self._t_meth if isinstance(self._t_meth, str) - else self._t_meth[idf]) - - logger.info(f'Coarsening feature "{feature}" with {self.s_enhance}x ' - f'spatial averaging and "{t_meth}" {self.t_enhance}x ' - 'temporal averaging') - - data = spatial_coarsening(data, s_enhance=self.s_enhance, - obs_axis=False) + t_meth = ( + self._t_meth + if isinstance(self._t_meth, str) + else self._t_meth[idf] + ) + + logger.info( + f'Coarsening feature "{feature}" with {self.s_enhance}x ' 
+ f'spatial averaging and "{t_meth}" {self.t_enhance}x ' + 'temporal averaging' + ) + + data = spatial_coarsening( + data, s_enhance=self.s_enhance, obs_axis=False + ) # t_coarse needs shape to be 5D: (obs, s1, s2, t, f) data = np.expand_dims(data, axis=0) data = np.expand_dims(data, axis=4) - data = temporal_coarsening(data, t_enhance=self.t_enhance, - method=t_meth) + data = temporal_coarsening( + data, t_enhance=self.t_enhance, method=t_meth + ) data = data[0] data = data[..., 0] @@ -510,7 +548,7 @@ def get_node_cmd(cls, config): initialize Sup3rQa and execute Sup3rQa.run() """ import_str = 'import time;\n' - import_str += 'from reV.pipeline.status import Status;\n' + import_str += 'from sup3r.pipeline import Status;\n' import_str += 'from rex import init_logger;\n' import_str += 'from sup3r.qa.qa import Sup3rQa;\n' @@ -519,19 +557,21 @@ def get_node_cmd(cls, config): log_file = config.get('log_file', None) log_level = config.get('log_level', 'INFO') - log_arg_str = (f'"sup3r", log_level="{log_level}"') + log_arg_str = f'"sup3r", log_level="{log_level}"' if log_file is not None: log_arg_str += f', log_file="{log_file}"' - cmd = (f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"qa = {qa_init_str};\n" - "qa.run();\n" - "t_elap = time.time() - t0;\n") + cmd = ( + f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"qa = {qa_init_str};\n" + "qa.run();\n" + "t_elap = time.time() - t0;\n" + ) cmd = BaseCLI.add_status_cmd(config, ModuleName.QA, cmd) - cmd += (";\'\n") + cmd += ";\'\n" return cmd.replace('\\', '/') @@ -574,10 +614,14 @@ def export(self, qa_fp, data, dset_name, dset_suffix=''): # transpose and flatten to typical h5 (time, space) dimensions data = np.transpose(data, axes=(2, 0, 1)).reshape(shape) - RexOutputs.add_dataset(qa_fp, dset_name, data, - dtype=attrs['dtype'], - chunks=attrs.get('chunks', None), - attrs=attrs) + RexOutputs.add_dataset( + qa_fp, + dset_name, + data, + dtype=attrs['dtype'], + chunks=attrs.get('chunks', None), + attrs=attrs, + ) def run(self): """Go through all datasets and get the error for the re-coarsened @@ -594,19 +638,24 @@ def run(self): errors = {} ziter = zip(self.features, self.source_features, self.output_names) for idf, (feature, source_feature, dset_out) in enumerate(ziter): - logger.info('Running QA on dataset {} of {} for "{}" ' - 'corresponding to source feature "{}"' - .format(idf + 1, len(self.features), feature, - source_feature)) + logger.info( + 'Running QA on dataset {} of {} for "{}" ' + 'corresponding to source feature "{}"'.format( + idf + 1, len(self.features), feature, source_feature + ) + ) data_syn = self.get_dset_out(feature) data_syn = self.coarsen_data(idf, feature, data_syn) data_true = self.get_source_dset(feature, source_feature) if data_syn.shape != data_true.shape: - msg = ('Sup3rQa failed while trying to inspect the "{}" ' - 'feature. The source low-res data had shape {} ' - 'while the re-coarsened synthetic data had shape {}.' - .format(feature, data_true.shape, data_syn.shape)) + msg = ( + 'Sup3rQa failed while trying to inspect the "{}" feature. 
' + 'The source low-res data had shape {} while the ' + 're-coarsened synthetic data had shape {}.'.format( + feature, data_true.shape, data_syn.shape + ) + ) logger.error(msg) raise RuntimeError(msg) diff --git a/sup3r/qa/stats.py b/sup3r/qa/stats.py index d82d5a53c..d71504010 100644 --- a/sup3r/qa/stats.py +++ b/sup3r/qa/stats.py @@ -115,8 +115,10 @@ def export(self, qa_fp, data): with open(qa_fp, 'wb') as f: pickle.dump(data, f, protocol=4) else: - logger.info(f'{qa_fp} already exists. Delete file or run with ' - 'overwrite_stats=True.') + logger.info( + f'{qa_fp} already exists. Delete file or run with ' + 'overwrite_stats=True.' + ) @classmethod def get_node_cmd(cls, config): @@ -130,7 +132,7 @@ def get_node_cmd(cls, config): initialize Sup3rStats and execute Sup3rStats.run() """ import_str = 'import time;\n' - import_str += 'from reV.pipeline.status import Status;\n' + import_str += 'from sup3r.pipeline import Status;\n' import_str += 'from rex import init_logger;\n' import_str += f'from sup3r.qa.stats import {cls.__name__};\n' @@ -139,19 +141,21 @@ def get_node_cmd(cls, config): log_file = config.get('log_file', None) log_level = config.get('log_level', 'INFO') - log_arg_str = (f'"sup3r", log_level="{log_level}"') + log_arg_str = f'"sup3r", log_level="{log_level}"' if log_file is not None: log_arg_str += f', log_file="{log_file}"' - cmd = (f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"qa = {qa_init_str};\n" - "qa.run();\n" - "t_elap = time.time() - t0;\n") + cmd = ( + f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"qa = {qa_init_str};\n" + "qa.run();\n" + "t_elap = time.time() - t0;\n" + ) cmd = BaseCLI.add_status_cmd(config, ModuleName.STATS, cmd) - cmd += (";\'\n") + cmd += ";\'\n" return cmd.replace('\\', '/') @@ -159,13 +163,27 @@ def get_node_cmd(cls, config): class Sup3rStatsCompute(Sup3rStatsBase): """Base class for computing stats on input data arrays""" - def __init__(self, input_data=None, s_enhance=1, t_enhance=1, - compute_features=None, input_features=None, - cache_pattern=None, overwrite_cache=False, - overwrite_stats=True, get_interp=False, - include_stats=None, max_values=None, smoothing=None, - spatial_res=None, temporal_res=None, n_bins=40, qa_fp=None, - interp_dists=True, time_chunk_size=100): + def __init__( + self, + input_data=None, + s_enhance=1, + t_enhance=1, + compute_features=None, + input_features=None, + cache_pattern=None, + overwrite_cache=False, + overwrite_stats=True, + get_interp=False, + include_stats=None, + max_values=None, + smoothing=None, + spatial_res=None, + temporal_res=None, + n_bins=40, + qa_fp=None, + interp_dists=True, + time_chunk_size=100, + ): """Parameters ---------- input_data : ndarray @@ -229,8 +247,10 @@ def __init__(self, input_data=None, s_enhance=1, t_enhance=1, msg = 'Preparing to compute statistics.' if input_data is None: - msg = ('Received empty input array. Skipping statistics ' - 'computations.') + msg = ( + 'Received empty input array. Skipping statistics ' + 'computations.' 
+ ) logger.info(msg) self.max_values = max_values or {} @@ -238,8 +258,12 @@ def __init__(self, input_data=None, s_enhance=1, t_enhance=1, self.direct_max = self.max_values.get(self._DIRECT, None) self.time_derivative_max = self.max_values.get(self._DY_DT, None) self.gradient_max = self.max_values.get(self._DY_DX, None) - self.include_stats = include_stats or [self._DIRECT, self._DY_DX, - self._DY_DT, self._FFT_K] + self.include_stats = include_stats or [ + self._DIRECT, + self._DY_DX, + self._DY_DT, + self._FFT_K, + ] self.s_enhance = s_enhance self.t_enhance = t_enhance self._features = compute_features @@ -351,8 +375,9 @@ def get_fluctuation(var): (spatial_1, spatial_2, temporal) """ avg = np.mean(var, axis=-1) - return var - np.repeat(np.expand_dims(avg, axis=-1), var.shape[-1], - axis=-1) + return var - np.repeat( + np.expand_dims(avg, axis=-1), var.shape[-1], axis=-1 + ) def interpolate_data(self, feature, low_res): """Get interpolated low res field @@ -382,13 +407,16 @@ def interpolate_data(self, feature, low_res): slices = [slice(s[0], s[-1] + 1) for s in slices] for i, s in enumerate(slices): - chunks.append(st_interp(low_res[..., s], self.s_enhance, - self.t_enhance)) + chunks.append( + st_interp(low_res[..., s], self.s_enhance, self.t_enhance) + ) mem = psutil.virtual_memory() - logger.info(f'Finished interpolating {i+1} / {len(slices)} ' - 'chunks. Current memory usage is ' - f'{mem.used / 1e9:.3f} GB out of ' - f'{mem.total / 1e9:.3f} GB total.') + logger.info( + f'Finished interpolating {i+1} / {len(slices)} ' + 'chunks. Current memory usage is ' + f'{mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.' + ) var_itp = np.concatenate(chunks, axis=-1) if 'direction' in feature: @@ -423,8 +451,9 @@ def check_return_cache(self, feature, shape): shape_str = f'{shape[0]}x{shape[1]}x{shape[2]}' if self.cache_pattern is not None: file_name = self.cache_pattern.replace('{shape}', f'{shape_str}') - file_name = file_name.replace('{feature}', - f'{feature.lower()}_interp') + file_name = file_name.replace( + '{feature}', f'{feature.lower()}_interp' + ) if file_name is not None and os.path.exists(file_name): var_itp = self.load_cache(file_name) return var_itp, file_name @@ -461,8 +490,11 @@ def _compute_dist_type(self, var, stat_type, interp=False, period=None): """ tmp = var.copy() if 'mean' in stat_type: - tmp = (np.mean(tmp, axis=-1) if 'time' not in stat_type - else np.mean(tmp, axis=(0, 1))) + tmp = ( + np.mean(tmp, axis=-1) + if 'time' not in stat_type + else np.mean(tmp, axis=(0, 1)) + ) if self._DIRECT in stat_type: max_val = self.direct_max method = direct_dist @@ -470,18 +502,30 @@ def _compute_dist_type(self, var, stat_type, interp=False, period=None): elif self._DY_DX in stat_type: max_val = self.gradient_max method = gradient_dist - scale = (self.spatial_res if not interp - else self.spatial_res / self.s_enhance) + scale = ( + self.spatial_res + if not interp + else self.spatial_res / self.s_enhance + ) elif self._DY_DT in stat_type: max_val = self.time_derivative_max method = time_derivative_dist - scale = (self.temporal_res if not interp - else self.temporal_res / self.t_enhance) + scale = ( + self.temporal_res + if not interp + else self.temporal_res / self.t_enhance + ) else: return None - kwargs = dict(var=tmp, diff_max=max_val, bins=self.n_bins, scale=scale, - interpolate=self.interp_dists, period=period) + kwargs = dict( + var=tmp, + diff_max=max_val, + bins=self.n_bins, + scale=scale, + interpolate=self.interp_dists, + period=period, + ) return 
method(**kwargs) def get_stats(self, var, interp=False, period=None): @@ -507,17 +551,19 @@ def get_stats(self, var, interp=False, period=None): """ stats_dict = {} for stat_type in self.include_stats: - if 'spectrum' in stat_type: out = self._compute_spectra_type(var, stat_type, interp=interp) else: - out = self._compute_dist_type(var, stat_type, interp=interp, - period=period) + out = self._compute_dist_type( + var, stat_type, interp=interp, period=period + ) if out is not None: mem = psutil.virtual_memory() - logger.info(f'Computed {stat_type}. Current memory usage is ' - f'{mem.used / 1e9:.3f} GB out of ' - f'{mem.total / 1e9:.3f} GB total.') + logger.info( + f'Computed {stat_type}. Current memory usage is ' + f'{mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.' + ) stats_dict[stat_type] = out return stats_dict @@ -543,9 +589,11 @@ def get_feature_data(self, feature): lower_features = [f.lower() for f in self.input_features] uidx = lower_features.index(f'u_{height}m') vidx = lower_features.index(f'v_{height}m') - out = vorticity_calc(self.source_data[..., uidx], - self.source_data[..., vidx], - scale=self.spatial_res) + out = vorticity_calc( + self.source_data[..., uidx], + self.source_data[..., vidx], + scale=self.spatial_res, + ) else: idx = self.input_features.index(feature) out = self.source_data[..., idx] @@ -601,8 +649,10 @@ def run(self): source, interp = self.get_feature_stats(feature) mem = psutil.virtual_memory() - logger.info(f'Current memory usage is {mem.used / 1e9:.3f} ' - f'GB out of {mem.total / 1e9:.3f} GB total.') + logger.info( + f'Current memory usage is {mem.used / 1e9:.3f} ' + f'GB out of {mem.total / 1e9:.3f} GB total.' + ) if self.source_data is not None: source_stats[feature] = source @@ -622,16 +672,33 @@ def run(self): class Sup3rStatsSingle(Sup3rStatsCompute): """Base class for doing statistical QA on single file set.""" - def __init__(self, source_file_paths=None, - s_enhance=1, t_enhance=1, features=None, - temporal_slice=slice(None), target=None, shape=None, - raster_file=None, time_chunk_size=None, - cache_pattern=None, overwrite_cache=False, - overwrite_stats=False, source_handler=None, - worker_kwargs=None, get_interp=False, include_stats=None, - max_values=None, smoothing=None, coarsen=False, - spatial_res=None, temporal_res=None, n_bins=40, - max_delta=10, qa_fp=None): + def __init__( + self, + source_file_paths=None, + s_enhance=1, + t_enhance=1, + features=None, + temporal_slice=slice(None), + target=None, + shape=None, + raster_file=None, + time_chunk_size=None, + cache_pattern=None, + overwrite_cache=False, + overwrite_stats=False, + source_handler=None, + worker_kwargs=None, + get_interp=False, + include_stats=None, + max_values=None, + smoothing=None, + coarsen=False, + spatial_res=None, + temporal_res=None, + n_bins=40, + max_delta=10, + qa_fp=None, + ): """Parameters ---------- source_file_paths : list | str @@ -742,8 +809,10 @@ def __init__(self, source_file_paths=None, File path for saving statistics. Only .pkl supported. """ - logger.info('Initializing Sup3rStatsSingle and retrieving source data' - f' for features={features}.') + logger.info( + 'Initializing Sup3rStatsSingle and retrieving source data' + f' for features={features}.' 
+ ) worker_kwargs = worker_kwargs or {} max_workers = worker_kwargs.get('max_workers', None) @@ -778,30 +847,39 @@ def __init__(self, source_file_paths=None, self._k_range = None self._f_range = None - source_handler_kwargs = dict(target=target, - shape=shape, - temporal_slice=temporal_slice, - raster_file=raster_file, - cache_pattern=cache_pattern, - time_chunk_size=time_chunk_size, - overwrite_cache=overwrite_cache, - worker_kwargs=worker_kwargs, - max_delta=max_delta) - self.source_data = self.get_source_data(source_file_paths, - source_handler_kwargs) - - super().__init__(self.source_data, s_enhance=s_enhance, - t_enhance=t_enhance, - compute_features=self.compute_features, - input_features=self.input_features, - cache_pattern=cache_pattern, - overwrite_cache=overwrite_cache, - overwrite_stats=overwrite_stats, - get_interp=get_interp, include_stats=include_stats, - max_values=max_values, smoothing=smoothing, - spatial_res=spatial_res, - temporal_res=self.temporal_res, n_bins=n_bins, - qa_fp=qa_fp) + source_handler_kwargs = dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + raster_file=raster_file, + cache_pattern=cache_pattern, + time_chunk_size=time_chunk_size, + overwrite_cache=overwrite_cache, + worker_kwargs=worker_kwargs, + max_delta=max_delta, + ) + self.source_data = self.get_source_data( + source_file_paths, source_handler_kwargs + ) + + super().__init__( + self.source_data, + s_enhance=s_enhance, + t_enhance=t_enhance, + compute_features=self.compute_features, + input_features=self.input_features, + cache_pattern=cache_pattern, + overwrite_cache=overwrite_cache, + overwrite_stats=overwrite_stats, + get_interp=get_interp, + include_stats=include_stats, + max_values=max_values, + smoothing=smoothing, + spatial_res=spatial_res, + temporal_res=self.temporal_res, + n_bins=n_bins, + qa_fp=qa_fp, + ) def close(self): """Close any open file handlers""" @@ -822,8 +900,10 @@ def source_type(self): ftype = get_source_type(self.source_file_paths) if ftype not in ('nc', 'h5'): - msg = ('Did not recognize source file type: ' - f'{self.source_file_paths}') + msg = ( + 'Did not recognize source file type: ' + f'{self.source_file_paths}' + ) logger.error(msg) raise TypeError(msg) return ftype @@ -831,8 +911,9 @@ def source_type(self): @property def source_handler_class(self): """Get source handler class""" - HandlerClass = get_input_handler_class(self.source_file_paths, - self._source_handler_class) + HandlerClass = get_input_handler_class( + self.source_file_paths, self._source_handler_class + ) return HandlerClass @property @@ -863,16 +944,18 @@ def get_source_data(self, file_paths, handler_kwargs=None): if file_paths is None: return None - self._source_handler = self.source_handler_class(file_paths, - self.input_features, - val_split=0.0, - **handler_kwargs) + self._source_handler = self.source_handler_class( + file_paths, self.input_features, val_split=0.0, **handler_kwargs + ) self._source_handler.load_cached_data() if self.coarsen: - logger.info('Coarsening data with shape=' - f'{self._source_handler.data.shape}') + logger.info( + 'Coarsening data with shape=' + f'{self._source_handler.data.shape}' + ) self._source_handler.data = self.coarsen_data( - self._source_handler.data, smoothing=self.smoothing) + self._source_handler.data, smoothing=self.smoothing + ) logger.info(f'Coarsened shape={self._source_handler.data.shape}') return self._source_handler.data @@ -897,8 +980,12 @@ def meta(self): ------- pd.DataFrame """ - meta = pd.DataFrame({'latitude': 
self.lat_lon[..., 0].flatten(), - 'longitude': self.lat_lon[..., 1].flatten()}) + meta = pd.DataFrame( + { + 'latitude': self.lat_lon[..., 0].flatten(), + 'longitude': self.lat_lon[..., 1].flatten(), + } + ) return meta @property @@ -919,8 +1006,9 @@ def input_features(self): ------- list """ - self._input_features = [f for f in self.compute_features if 'vorticity' - not in f] + self._input_features = [ + f for f in self.compute_features if 'vorticity' not in f + ] for feature in self.compute_features: if 'vorticity' in feature: height = Feature.get_height(feature) @@ -935,8 +1023,9 @@ def input_features(self): @input_features.setter def input_features(self, input_features): """Set input features""" - self._input_features = [f for f in input_features if 'vorticity' - not in f] + self._input_features = [ + f for f in input_features if 'vorticity' not in f + ] for feature in input_features: if 'vorticity' in feature: height = Feature.get_height(feature) @@ -972,9 +1061,9 @@ def coarsen_data(self, data, smoothing=None): """ n_lats = self.s_enhance * (data.shape[0] // self.s_enhance) n_lons = self.s_enhance * (data.shape[1] // self.s_enhance) - data = spatial_coarsening(data[:n_lats, :n_lons], - s_enhance=self.s_enhance, - obs_axis=False) + data = spatial_coarsening( + data[:n_lats, :n_lons], s_enhance=self.s_enhance, obs_axis=False + ) # t_coarse needs shape to be 5D: (obs, s1, s2, t, f) data = np.expand_dims(data, axis=0) @@ -984,9 +1073,9 @@ def coarsen_data(self, data, smoothing=None): if smoothing is not None: for i in range(data.shape[-1]): for t in range(data.shape[-2]): - data[..., t, i] = gaussian_filter(data[..., t, i], - smoothing, - mode='nearest') + data[..., t, i] = gaussian_filter( + data[..., t, i], smoothing, mode='nearest' + ) return data @@ -996,17 +1085,39 @@ class Sup3rStatsMulti(Sup3rStatsBase): high resolution corresponding to the low resolution input. 
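
# A simplified numpy/scipy sketch of the coarsen-and-smooth step in
# coarsen_data() above: crop to a multiple of s_enhance, block-average the two
# spatial dimensions, then optionally smooth each 2D slice with a gaussian
# filter. This is a stand-in for sup3r's spatial_coarsening utility and skips
# the temporal coarsening step.
import numpy as np
from scipy.ndimage import gaussian_filter

def coarsen_and_smooth(data, s_enhance=2, smoothing=None):
    """data is assumed to have shape (lat, lon, time, features)."""
    n_lats = s_enhance * (data.shape[0] // s_enhance)
    n_lons = s_enhance * (data.shape[1] // s_enhance)
    data = data[:n_lats, :n_lons]
    out = data.reshape(n_lats // s_enhance, s_enhance,
                       n_lons // s_enhance, s_enhance,
                       *data.shape[2:]).mean(axis=(1, 3))
    if smoothing is not None:
        for i in range(out.shape[-1]):
            for t in range(out.shape[-2]):
                out[..., t, i] = gaussian_filter(out[..., t, i], smoothing,
                                                 mode='nearest')
    return out

print(coarsen_and_smooth(np.random.rand(10, 12, 4, 2), smoothing=1.0).shape)
# (5, 6, 4, 2)
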
This class will provide statistics used to compare all these datasets.""" - def __init__(self, lr_file_paths=None, synth_file_paths=None, - hr_file_paths=None, s_enhance=1, t_enhance=1, features=None, - lr_t_slice=slice(None), synth_t_slice=slice(None), - hr_t_slice=slice(None), target=None, shape=None, - raster_file=None, qa_fp=None, time_chunk_size=None, - cache_pattern=None, overwrite_cache=False, - overwrite_synth_cache=False, overwrite_stats=False, - source_handler=None, output_handler=None, worker_kwargs=None, - get_interp=False, include_stats=None, max_values=None, - smoothing=None, spatial_res=None, temporal_res=None, - n_bins=40, max_delta=10, save_fig_data=False): + def __init__( + self, + lr_file_paths=None, + synth_file_paths=None, + hr_file_paths=None, + s_enhance=1, + t_enhance=1, + features=None, + lr_t_slice=slice(None), + synth_t_slice=slice(None), + hr_t_slice=slice(None), + target=None, + shape=None, + raster_file=None, + qa_fp=None, + time_chunk_size=None, + cache_pattern=None, + overwrite_cache=False, + overwrite_synth_cache=False, + overwrite_stats=False, + source_handler=None, + output_handler=None, + worker_kwargs=None, + get_interp=False, + include_stats=None, + max_values=None, + smoothing=None, + spatial_res=None, + temporal_res=None, + n_bins=40, + max_delta=10, + save_fig_data=False, + ): """Parameters ---------- lr_file_paths : list | str @@ -1136,8 +1247,10 @@ def __init__(self, lr_file_paths=None, synth_file_paths=None, Number of bins to use for constructing probability distributions """ - logger.info('Initializing Sup3rStatsMulti and retrieving source data' - f' for features={features}.') + logger.info( + 'Initializing Sup3rStatsMulti and retrieving source data' + f' for features={features}.' + ) self.qa_fp = qa_fp self.overwrite_stats = overwrite_stats @@ -1146,20 +1259,29 @@ def __init__(self, lr_file_paths=None, synth_file_paths=None, # get low res and interp stats logger.info('Retrieving source data for low-res and interp stats') - kwargs = dict(source_file_paths=lr_file_paths, - s_enhance=s_enhance, t_enhance=t_enhance, - features=features, temporal_slice=lr_t_slice, - target=target, shape=shape, - time_chunk_size=time_chunk_size, - cache_pattern=cache_pattern, - overwrite_cache=overwrite_cache, - overwrite_stats=overwrite_stats, - source_handler=source_handler, - worker_kwargs=worker_kwargs, - get_interp=get_interp, include_stats=include_stats, - max_values=max_values, smoothing=None, - spatial_res=spatial_res, temporal_res=temporal_res, - n_bins=n_bins, max_delta=max_delta) + kwargs = dict( + source_file_paths=lr_file_paths, + s_enhance=s_enhance, + t_enhance=t_enhance, + features=features, + temporal_slice=lr_t_slice, + target=target, + shape=shape, + time_chunk_size=time_chunk_size, + cache_pattern=cache_pattern, + overwrite_cache=overwrite_cache, + overwrite_stats=overwrite_stats, + source_handler=source_handler, + worker_kwargs=worker_kwargs, + get_interp=get_interp, + include_stats=include_stats, + max_values=max_values, + smoothing=None, + spatial_res=spatial_res, + temporal_res=temporal_res, + n_bins=n_bins, + max_delta=max_delta, + ) self.lr_stats = Sup3rStatsSingle(**kwargs) if self.lr_stats.source_data is not None: @@ -1170,62 +1292,99 @@ def __init__(self, lr_file_paths=None, synth_file_paths=None, # get high res stats shape = (self.lr_shape[0] * s_enhance, self.lr_shape[1] * s_enhance) - logger.info('Retrieving source data for high-res stats with ' - f'shape={shape}') - tmp_raster = (raster_file if raster_file is None - else 
raster_file.replace('.txt', '_hr.txt')) - tmp_cache = (cache_pattern if cache_pattern is None - else cache_pattern.replace('.pkl', '_hr.pkl')) + logger.info( + 'Retrieving source data for high-res stats with ' f'shape={shape}' + ) + tmp_raster = ( + raster_file + if raster_file is None + else raster_file.replace('.txt', '_hr.txt') + ) + tmp_cache = ( + cache_pattern + if cache_pattern is None + else cache_pattern.replace('.pkl', '_hr.pkl') + ) hr_spatial_res = spatial_res or 1 hr_spatial_res /= s_enhance hr_temporal_res = temporal_res or 1 hr_temporal_res /= t_enhance - kwargs_new = dict(source_file_paths=hr_file_paths, - s_enhance=1, t_enhance=1, - shape=shape, target=target, - spatial_res=hr_spatial_res, - temporal_res=hr_temporal_res, - get_interp=False, source_handler=source_handler, - cache_pattern=tmp_cache, - temporal_slice=hr_t_slice) + kwargs_new = dict( + source_file_paths=hr_file_paths, + s_enhance=1, + t_enhance=1, + shape=shape, + target=target, + spatial_res=hr_spatial_res, + temporal_res=hr_temporal_res, + get_interp=False, + source_handler=source_handler, + cache_pattern=tmp_cache, + temporal_slice=hr_t_slice, + ) kwargs_hr = kwargs.copy() kwargs_hr.update(kwargs_new) self.hr_stats = Sup3rStatsSingle(**kwargs_hr) # get synthetic stats shape = (self.lr_shape[0] * s_enhance, self.lr_shape[1] * s_enhance) - logger.info('Retrieving source data for synthetic stats with ' - f'shape={shape}') - tmp_raster = (raster_file if raster_file is None - else raster_file.replace('.txt', '_synth.txt')) - tmp_cache = (cache_pattern if cache_pattern is None - else cache_pattern.replace('.pkl', '_synth.pkl')) - kwargs_new = dict(source_file_paths=synth_file_paths, - s_enhance=1, t_enhance=1, - shape=shape, target=target, - spatial_res=hr_spatial_res, - temporal_res=hr_temporal_res, - get_interp=False, source_handler=output_handler, - raster_file=tmp_raster, cache_pattern=tmp_cache, - overwrite_cache=(overwrite_synth_cache), - temporal_slice=synth_t_slice) + logger.info( + 'Retrieving source data for synthetic stats with ' f'shape={shape}' + ) + tmp_raster = ( + raster_file + if raster_file is None + else raster_file.replace('.txt', '_synth.txt') + ) + tmp_cache = ( + cache_pattern + if cache_pattern is None + else cache_pattern.replace('.pkl', '_synth.pkl') + ) + kwargs_new = dict( + source_file_paths=synth_file_paths, + s_enhance=1, + t_enhance=1, + shape=shape, + target=target, + spatial_res=hr_spatial_res, + temporal_res=hr_temporal_res, + get_interp=False, + source_handler=output_handler, + raster_file=tmp_raster, + cache_pattern=tmp_cache, + overwrite_cache=(overwrite_synth_cache), + temporal_slice=synth_t_slice, + ) kwargs_synth = kwargs.copy() kwargs_synth.update(kwargs_new) self.synth_stats = Sup3rStatsSingle(**kwargs_synth) # get coarse stats logger.info('Retrieving source data for coarse stats') - tmp_raster = (raster_file if raster_file is None - else raster_file.replace('.txt', '_coarse.txt')) - tmp_cache = (cache_pattern if cache_pattern is None - else cache_pattern.replace('.pkl', '_coarse.pkl')) - kwargs_new = dict(source_file_paths=hr_file_paths, - spatial_res=spatial_res, temporal_res=temporal_res, - target=target, shape=shape, smoothing=smoothing, - coarsen=True, get_interp=False, - source_handler=output_handler, - cache_pattern=tmp_cache, - temporal_slice=hr_t_slice) + tmp_raster = ( + raster_file + if raster_file is None + else raster_file.replace('.txt', '_coarse.txt') + ) + tmp_cache = ( + cache_pattern + if cache_pattern is None + else cache_pattern.replace('.pkl', 
'_coarse.pkl') + ) + kwargs_new = dict( + source_file_paths=hr_file_paths, + spatial_res=spatial_res, + temporal_res=temporal_res, + target=target, + shape=shape, + smoothing=smoothing, + coarsen=True, + get_interp=False, + source_handler=output_handler, + cache_pattern=tmp_cache, + temporal_slice=hr_t_slice, + ) kwargs_coarse = kwargs.copy() kwargs_coarse.update(kwargs_new) self.coarse_stats = Sup3rStatsSingle(**kwargs_coarse) @@ -1236,20 +1395,30 @@ def export_fig_data(self): fig_data = {} if self.synth_stats.source_data is not None: fig_data.update( - {'time_index': self.synth_stats.time_index, - 'synth': self.synth_stats.get_feature_data(feature), - 'synth_grid': self.synth_stats.source_handler.lat_lon}) + { + 'time_index': self.synth_stats.time_index, + 'synth': self.synth_stats.get_feature_data(feature), + 'synth_grid': self.synth_stats.source_handler.lat_lon, + } + ) if self.lr_stats.source_data is not None: fig_data.update( - {'low_res': self.lr_stats.get_feature_data(feature), - 'low_res_grid': self.lr_stats.source_handler.lat_lon}) + { + 'low_res': self.lr_stats.get_feature_data(feature), + 'low_res_grid': self.lr_stats.source_handler.lat_lon, + } + ) if self.hr_stats.source_data is not None: fig_data.update( - {'high_res': self.hr_stats.get_feature_data(feature), - 'high_res_grid': self.hr_stats.source_handler.lat_lon}) + { + 'high_res': self.hr_stats.get_feature_data(feature), + 'high_res_grid': self.hr_stats.source_handler.lat_lon, + } + ) if self.coarse_stats.source_data is not None: fig_data.update( - {'coarse': self.coarse_stats.get_feature_data(feature)}) + {'coarse': self.coarse_stats.get_feature_data(feature)} + ) file_name = self.qa_fp.replace('.pkl', f'_{feature}_compare.pkl') with open(file_name, 'wb') as fp: @@ -1258,8 +1427,12 @@ def export_fig_data(self): def close(self): """Close any open file handlers""" - stats = [self.lr_stats, self.hr_stats, self.synth_stats, - self.coarse_stats] + stats = [ + self.lr_stats, + self.hr_stats, + self.synth_stats, + self.coarse_stats, + ] for s_handle in stats: s_handle.close() @@ -1282,13 +1455,15 @@ def run(self): if lr_stats['interp']: stats['interp'] = lr_stats['interp'] if self.synth_stats.source_data is not None: - logger.info('Computing statistics on synthetic high-resolution ' - 'dataset.') + logger.info( + 'Computing statistics on synthetic high-resolution dataset.' + ) synth_stats = self.synth_stats.run() stats['synth'] = synth_stats['source'] if self.coarse_stats.source_data is not None: - logger.info('Computing statistics on coarsened low-resolution ' - 'dataset.') + logger.info( + 'Computing statistics on coarsened low-resolution dataset.' 
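
# A hedged usage sketch for the Sup3rStatsMulti workflow assembled above. The
# import path and all file paths are assumptions for illustration; running it
# requires real low-res, synthetic, and high-res sup3r output plus matching
# s_enhance/t_enhance factors.
from sup3r.qa.stats import Sup3rStatsMulti

stats = Sup3rStatsMulti(lr_file_paths='./era5_2015.nc',
                        synth_file_paths='./sup3r_synth_*.h5',
                        hr_file_paths='./wtk_2015.h5',
                        s_enhance=5,
                        t_enhance=24,
                        features=['u_100m', 'v_100m'],
                        get_interp=True,
                        qa_fp='./stats.pkl')
stats.run()
stats.close()
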
+ ) coarse_stats = self.coarse_stats.run() stats['coarse'] = coarse_stats['source'] if self.hr_stats.source_data is not None: diff --git a/sup3r/qa/visual_qa.py b/sup3r/qa/visual_qa.py index 37cc2edcf..06ee516ef 100644 --- a/sup3r/qa/visual_qa.py +++ b/sup3r/qa/visual_qa.py @@ -1,29 +1,37 @@ # -*- coding: utf-8 -*- """Module to plot feature output from forward passes for visual inspection""" -import numpy as np -import matplotlib.pyplot as plt -import logging import glob -from datetime import datetime as dt +import logging import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime as dt +import matplotlib.pyplot as plt +import numpy as np import rex from rex.utilities.fun_utils import get_fun_call_str -from concurrent.futures import ThreadPoolExecutor, as_completed from sup3r.utilities import ModuleName from sup3r.utilities.cli import BaseCLI - logger = logging.getLogger(__name__) class Sup3rVisualQa: """Module to plot features for visual qa""" - def __init__(self, file_paths, out_pattern, features, time_step=10, - spatial_slice=slice(None), source_handler_class=None, - max_workers=None, overwrite=False, **kwargs): + def __init__( + self, + file_paths, + out_pattern, + features, + time_step=10, + spatial_slice=slice(None), + source_handler_class=None, + max_workers=None, + overwrite=False, + **kwargs, + ): """ Parameters ---------- @@ -61,10 +69,16 @@ def __init__(self, file_paths, out_pattern, features, time_step=10, self.features = features self.out_pattern = out_pattern self.time_step = time_step - self.spatial_slice = (spatial_slice if isinstance(spatial_slice, slice) - else slice(*spatial_slice)) - self.file_paths = (file_paths if isinstance(file_paths, list) - else glob.glob(file_paths)) + self.spatial_slice = ( + spatial_slice + if isinstance(spatial_slice, slice) + else slice(*spatial_slice) + ) + self.file_paths = ( + file_paths + if isinstance(file_paths, list) + else glob.glob(file_paths) + ) self.max_workers = max_workers self.kwargs = kwargs self.res_handler = source_handler_class or 'MultiFileResource' @@ -72,16 +86,18 @@ def __init__(self, file_paths, out_pattern, features, time_step=10, self.overwrite = overwrite if not os.path.exists(os.path.dirname(out_pattern)): os.makedirs(os.path.dirname(out_pattern), exist_ok=True) - logger.info('Initializing Sup3rVisualQa with ' - f'file_paths={self.file_paths}, ' - f'out_pattern={self.out_pattern}, ' - f'features={self.features}, ' - f'time_step={self.time_step}, ' - f'spatial_slice={self.spatial_slice}, ' - f'source_handler_class={self.res_handler}, ' - f'max_workers={max_workers}, ' - f'overwrite={self.overwrite}, ' - f'kwargs={kwargs}.') + logger.info( + 'Initializing Sup3rVisualQa with ' + f'file_paths={self.file_paths}, ' + f'out_pattern={self.out_pattern}, ' + f'features={self.features}, ' + f'time_step={self.time_step}, ' + f'spatial_slice={self.spatial_slice}, ' + f'source_handler_class={self.res_handler}, ' + f'max_workers={max_workers}, ' + f'overwrite={self.overwrite}, ' + f'kwargs={kwargs}.' 
+ ) def run(self): """ @@ -91,19 +107,22 @@ def run(self): """ with self.res_handler(self.file_paths) as res: time_index = res.time_index - n_files = len(time_index[::self.time_step]) + n_files = len(time_index[:: self.time_step]) time_slices = np.array_split(np.arange(len(time_index)), n_files) time_slices = [slice(s[0], s[-1] + 1) for s in time_slices] if self.max_workers == 1: - self._serial_figure_plots(res, time_index, time_slices, - self.spatial_slice) + self._serial_figure_plots( + res, time_index, time_slices, self.spatial_slice + ) else: - self._parallel_figure_plots(res, time_index, time_slices, - self.spatial_slice) + self._parallel_figure_plots( + res, time_index, time_slices, self.spatial_slice + ) - def _serial_figure_plots(self, res, time_index, time_slices, - spatial_slice): + def _serial_figure_plots( + self, res, time_index, time_slices, spatial_slice + ): """Plot figures in parallel with max_workers=self.workers Parameters @@ -119,13 +138,16 @@ def _serial_figure_plots(self, res, time_index, time_slices, """ for feature in self.features: for i, t_slice in enumerate(time_slices): - out_file = self.out_pattern.format(feature=feature, - index=str(i).zfill(8)) - self.plot_figure(res, time_index, feature, t_slice, - spatial_slice, out_file) + out_file = self.out_pattern.format( + feature=feature, index=str(i).zfill(8) + ) + self.plot_figure( + res, time_index, feature, t_slice, spatial_slice, out_file + ) - def _parallel_figure_plots(self, res, time_index, time_slices, - spatial_slice): + def _parallel_figure_plots( + self, res, time_index, time_slices, spatial_slice + ): """Plot figures in parallel with max_workers=self.workers Parameters @@ -145,27 +167,36 @@ def _parallel_figure_plots(self, res, time_index, time_slices, with ThreadPoolExecutor(max_workers=self.max_workers) as exe: for feature in self.features: for i, t_slice in enumerate(time_slices): - out_file = self.out_pattern.format(feature=feature, - index=str(i).zfill(8)) - future = exe.submit(self.plot_figure, res, time_index, - feature, t_slice, spatial_slice, - out_file) + out_file = self.out_pattern.format( + feature=feature, index=str(i).zfill(8) + ) + future = exe.submit( + self.plot_figure, + res, + time_index, + feature, + t_slice, + spatial_slice, + out_file, + ) futures[future] = out_file - logger.info(f'Started plotting {n_files} files ' - f'in {dt.now() - now}.') + logger.info( + f'Started plotting {n_files} files ' f'in {dt.now() - now}.' + ) for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = (f'Error making plot {futures[future]}.') + msg = f'Error making plot {futures[future]}.' logger.exception(msg) raise RuntimeError(msg) from e logger.debug(f'{i+1} out of {n_files} plots created.') - def plot_figure(self, res, time_index, feature, t_slice, s_slice, - out_file): + def plot_figure( + self, res, time_index, feature, t_slice, s_slice, out_file + ): """Plot temporal average for the given feature and with the time range specified by t_slice @@ -185,19 +216,26 @@ def plot_figure(self, res, time_index, feature, t_slice, s_slice, Name of the output plot file """ if not self.overwrite and os.path.exists(out_file): - logger.info(f'{out_file} already exists and overwrite=' - f'{self.overwrite}. Skipping this plot.') + logger.info( + f'{out_file} already exists and overwrite=' + f'{self.overwrite}. Skipping this plot.' 
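
# A small runnable sketch of the time-slice construction in run() above: take
# every time_step-th timestamp to decide how many plots to make, then split
# the full index into that many contiguous slices. A numpy range stands in for
# the resource handler's DatetimeIndex.
import numpy as np

time_index = np.arange(240)
time_step = 10
n_files = len(time_index[::time_step])
chunks = np.array_split(np.arange(len(time_index)), n_files)
time_slices = [slice(int(s[0]), int(s[-1]) + 1) for s in chunks]
print(n_files, time_slices[0])  # 24 slice(0, 10, None)
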
+ ) return start_time = time_index[t_slice.start] stop_time = time_index[t_slice.stop - 1] - logger.info(f'Plotting time average for {feature} from ' - f'{start_time} to {stop_time}.') + logger.info( + f'Plotting time average for {feature} from ' + f'{start_time} to {stop_time}.' + ) fig = plt.figure() title = f'{feature}: {start_time} - {stop_time}' plt.suptitle(title) - plt.scatter(res.meta.longitude[s_slice], res.meta.latitude[s_slice], - c=np.mean(res[feature, t_slice, s_slice], axis=0), - **self.kwargs) + plt.scatter( + res.meta.longitude[s_slice], + res.meta.latitude[s_slice], + c=np.mean(res[feature, t_slice, s_slice], axis=0), + **self.kwargs, + ) plt.colorbar() fig.savefig(out_file) plt.close() @@ -215,7 +253,7 @@ def get_node_cmd(cls, config): initialize Sup3rVisualQa and execute Sup3rVisualQa.run() """ import_str = 'import time;\n' - import_str += 'from reV.pipeline.status import Status;\n' + import_str += 'from sup3r.pipeline import Status;\n' import_str += 'from rex import init_logger;\n' import_str += 'from sup3r.qa.visual_qa import Sup3rVisualQa;\n' @@ -224,18 +262,20 @@ def get_node_cmd(cls, config): log_file = config.get('log_file', None) log_level = config.get('log_level', 'INFO') - log_arg_str = (f'"sup3r", log_level="{log_level}"') + log_arg_str = f'"sup3r", log_level="{log_level}"' if log_file is not None: log_arg_str += f', log_file="{log_file}"' - cmd = (f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"qa = {qa_init_str};\n" - "qa.run();\n" - "t_elap = time.time() - t0;\n") + cmd = ( + f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"qa = {qa_init_str};\n" + "qa.run();\n" + "t_elap = time.time() - t0;\n" + ) cmd = BaseCLI.add_status_cmd(config, ModuleName.VISUAL_QA, cmd) - cmd += (";\'\n") + cmd += ";\'\n" return cmd.replace('\\', '/') diff --git a/sup3r/solar/solar.py b/sup3r/solar/solar.py index 5c6570130..2a16059a5 100644 --- a/sup3r/solar/solar.py +++ b/sup3r/solar/solar.py @@ -29,8 +29,16 @@ class Solar: ratio to GHI, DNI, and DHI using NSRDB data and utility modules like DISC""" - def __init__(self, sup3r_fps, nsrdb_fp, t_slice=slice(None), tz=-6, - agg_factor=1, nn_threshold=0.5, cloud_threshold=0.99): + def __init__( + self, + sup3r_fps, + nsrdb_fp, + t_slice=slice(None), + tz=-6, + agg_factor=1, + nn_threshold=0.5, + cloud_threshold=0.99, + ): """ Parameters ---------- @@ -81,12 +89,21 @@ def __init__(self, sup3r_fps, nsrdb_fp, t_slice=slice(None), tz=-6, if isinstance(self._sup3r_fps, str): self._sup3r_fps = [self._sup3r_fps] - logger.debug('Initializing solar module with sup3r files: {}' - .format([os.path.basename(fp) for fp in self._sup3r_fps])) - logger.debug('Initializing solar module with temporal slice: {}' - .format(self.t_slice)) - logger.debug('Initializing solar module with NSRDB source fp: {}' - .format(self._nsrdb_fp)) + logger.debug( + 'Initializing solar module with sup3r files: {}'.format( + [os.path.basename(fp) for fp in self._sup3r_fps] + ) + ) + logger.debug( + 'Initializing solar module with temporal slice: {}'.format( + self.t_slice + ) + ) + logger.debug( + 'Initializing solar module with NSRDB source fp: {}'.format( + self._nsrdb_fp + ) + ) self.gan_data = MultiTimeResource(self._sup3r_fps) self.nsrdb = Resource(self._nsrdb_fp) @@ -124,8 +141,10 @@ def preflight(self): ti_gan = self.gan_data.time_index ti_gan_1 = np.roll(ti_gan, 1) delta = pd.Series(ti_gan - ti_gan_1)[1:].mean().total_seconds() - msg = ('Its assumed that the 
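
# A self-contained sketch of the plotting pattern in plot_figure() above:
# scatter the site coordinates colored by the temporal mean of a feature.
# Synthetic coordinates and data stand in for the resource handler, and the
# feature name in the title is made up.
import matplotlib
matplotlib.use('Agg')  # write to file without a display
import matplotlib.pyplot as plt
import numpy as np

lons = np.random.uniform(-106, -104, 500)
lats = np.random.uniform(39, 41, 500)
data = np.random.rand(24, 500)  # (time, sites)

fig = plt.figure()
plt.suptitle('windspeed_100m: temporal average')
plt.scatter(lons, lats, c=data.mean(axis=0), s=5)
plt.colorbar()
fig.savefig('windspeed_100m_00000000.png')
plt.close()
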
sup3r GAN output solar data will be ' - 'hourly but received time index: {}'.format(ti_gan)) + msg = ( + 'Its assumed that the sup3r GAN output solar data will be ' + 'hourly but received time index: {}'.format(ti_gan) + ) assert delta == 3600, msg def close(self): @@ -203,12 +222,14 @@ def nsrdb_tslice(self): mask = doy_nsrdb.isin(doy_gan) if mask.sum() == 0: - msg = ('Time index intersection of the NSRDB time index and ' - 'sup3r GAN output has only {} common timesteps! ' - 'Something went wrong.\nNSRDB time index: \n{}\nSup3r ' - 'GAN output time index:\n{}' - .format(mask.sum(), self.nsrdb.time_index, - self.time_index)) + msg = ( + 'Time index intersection of the NSRDB time index and ' + 'sup3r GAN output has only {} common timesteps! ' + 'Something went wrong.\nNSRDB time index: \n{}\nSup3r ' + 'GAN output time index:\n{}'.format( + mask.sum(), self.nsrdb.time_index, self.time_index + ) + ) logger.error(msg) raise RuntimeError(msg) @@ -221,10 +242,13 @@ def nsrdb_tslice(self): step = int(3600 // delta) self._nsrdb_tslice = slice(t0, t1, step) - logger.debug('Found nsrdb_tslice {} with corresponding ' - 'time index:\n\t{}' - .format(self._nsrdb_tslice, - self.nsrdb.time_index[self._nsrdb_tslice])) + logger.debug( + 'Found nsrdb_tslice {} with corresponding ' + 'time index:\n\t{}'.format( + self._nsrdb_tslice, + self.nsrdb.time_index[self._nsrdb_tslice], + ) + ) return self._nsrdb_tslice @@ -245,7 +269,7 @@ def clearsky_ratio(self): # if tz is negative, roll to utc is positive, and the beginning of # the dataset is rolled over from the end and must be backfilled, # otherwise you can get seams - self._cs_ratio[:-self.tz, :] = self._cs_ratio[-self.tz, :] + self._cs_ratio[: -self.tz, :] = self._cs_ratio[-self.tz, :] # apply temporal slicing of source data, see docstring on t_slice # for more info @@ -278,8 +302,9 @@ def ghi(self): """ if self._ghi is None: logger.debug('Calculating GHI.') - self._ghi = (self.get_nsrdb_data('clearsky_ghi') - * self.clearsky_ratio) + self._ghi = ( + self.get_nsrdb_data('clearsky_ghi') * self.clearsky_ratio + ) self._ghi[:, self.out_of_bounds] = 0 return self._ghi @@ -299,8 +324,9 @@ def dni(self): self._dni = self.get_nsrdb_data('clearsky_dni') pressure = self.get_nsrdb_data('surface_pressure') doy = self.time_index.day_of_year.values - cloudy_dni = disc(self.ghi, self.solar_zenith_angle, doy, - pressure=pressure) + cloudy_dni = disc( + self.ghi, self.solar_zenith_angle, doy, pressure=pressure + ) cloudy_dni = np.minimum(self._dni, cloudy_dni) self._dni[self.cloud_mask] = cloudy_dni[self.cloud_mask] self._dni = dark_night(self._dni, self.solar_zenith_angle) @@ -319,8 +345,9 @@ def dhi(self): """ if self._dhi is None: logger.debug('Calculating DHI.') - self._dhi, self._dni = calc_dhi(self.dni, self.ghi, - self.solar_zenith_angle) + self._dhi, self._dni = calc_dhi( + self.dni, self.ghi, self.solar_zenith_angle + ) self._dhi = dark_night(self._dhi, self.solar_zenith_angle) self._dhi[:, self.out_of_bounds] = 0 return self._dhi @@ -410,8 +437,9 @@ def get_sup3r_fps(fp_pattern, ignore=None): all_fps = [fp for fp in glob.glob(fp_pattern) if fp.endswith('.h5')] if ignore is not None: - all_fps = [fp for fp in all_fps - if ignore not in os.path.basename(fp)] + all_fps = [ + fp for fp in all_fps if ignore not in os.path.basename(fp) + ] all_fps = sorted(all_fps) @@ -419,10 +447,12 @@ def get_sup3r_fps(fp_pattern, ignore=None): source_fn_base = os.path.basename(all_fps[0]).replace('.h5', '') source_fn_base = '_'.join(source_fn_base.split('_')[:-2]) - all_id_spatial = 
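
# The irradiance arithmetic in the ghi/dni/dhi properties above reduces to
# scaling clearsky irradiance by the GAN's clearsky ratio and zeroing flagged
# columns; a minimal numpy sketch of the GHI case (DNI and DHI additionally go
# through the disc() and calc_dhi() utilities, not reproduced here):
import numpy as np

cs_ratio = np.random.rand(24, 10)             # (time, sites) from the GAN
clearsky_ghi = 1000 * np.random.rand(24, 10)  # from the NSRDB source
out_of_bounds = np.zeros(10, dtype=bool)
out_of_bounds[-1] = True                      # e.g. site too far from NSRDB

ghi = clearsky_ghi * cs_ratio
ghi[:, out_of_bounds] = 0
print(ghi.shape, ghi[:, -1].max())  # (24, 10) 0.0
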
[fp.replace('.h5', '').split('_')[-1] - for fp in all_fps] - all_id_temporal = [fp.replace('.h5', '').split('_')[-2] - for fp in all_fps] + all_id_spatial = [ + fp.replace('.h5', '').split('_')[-1] for fp in all_fps + ] + all_id_temporal = [ + fp.replace('.h5', '').split('_')[-2] for fp in all_fps + ] all_id_spatial = sorted(list(set(all_id_spatial))) all_id_temporal = sorted(list(set(all_id_temporal))) @@ -473,7 +503,7 @@ def get_node_cmd(cls, config): run Solar.run_temporal_chunk() on a single node. """ import_str = 'import time;\n' - import_str += 'from reV.pipeline.status import Status;\n' + import_str += 'from sup3r.pipeline import Status;\n' import_str += 'from rex import init_logger;\n' import_str += f'from sup3r.solar import {cls.__name__};\n' @@ -481,15 +511,17 @@ def get_node_cmd(cls, config): log_file = config.get('log_file', None) log_level = config.get('log_level', 'INFO') - log_arg_str = (f'"sup3r", log_level="{log_level}"') + log_arg_str = f'"sup3r", log_level="{log_level}"' if log_file is not None: log_arg_str += f', log_file="{log_file}"' - cmd = (f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"{fun_str};\n" - "t_elap = time.time() - t0;\n") + cmd = ( + f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"{fun_str};\n" + "t_elap = time.time() - t0;\n" + ) job_name = config.get('job_name', None) if job_name is not None: @@ -499,15 +531,17 @@ def get_node_cmd(cls, config): status_file_arg_str += f'job_name="{job_name}", ' status_file_arg_str += 'attrs=job_attrs' - cmd += ('job_attrs = {};\n'.format(json.dumps(config) - .replace("null", "None") - .replace("false", "False") - .replace("true", "True"))) + cmd += 'job_attrs = {};\n'.format( + json.dumps(config) + .replace("null", "None") + .replace("false", "False") + .replace("true", "True") + ) cmd += 'job_attrs.update({"job_status": "successful"});\n' cmd += 'job_attrs.update({"time": t_elap});\n' cmd += f'Status.make_job_file({status_file_arg_str})' - cmd += (";\'\n") + cmd += ";\'\n" return cmd.replace('\\', '/') @@ -535,15 +569,21 @@ def write(self, fp_out, features=('ghi', 'dni', 'dhi')): attrs = H5_ATTRS[feature] arr = getattr(self, feature, None) if arr is None: - msg = ('Feature "{}" was not available from Solar ' - 'module class.'.format(feature)) + msg = ( + 'Feature "{}" was not available from Solar ' + 'module class.'.format(feature) + ) logger.error(msg) raise AttributeError(msg) - fh.add_dataset(fp_out, feature, arr, - dtype=attrs['dtype'], - attrs=attrs, - chunks=attrs['chunks']) + fh.add_dataset( + fp_out, + feature, + arr, + dtype=attrs['dtype'], + attrs=attrs, + chunks=attrs['chunks'], + ) logger.info(f'Added "{feature}" to output file.') run_attrs = self.gan_data.h5[self._sup3r_fps[0]].global_attrs run_attrs['nsrdb_source'] = self._nsrdb_fp @@ -552,11 +592,18 @@ def write(self, fp_out, features=('ghi', 'dni', 'dhi')): logger.info(f'Finished writing file: {fp_out}') @classmethod - def run_temporal_chunk(cls, fp_pattern, nsrdb_fp, - fp_out_suffix='irradiance', tz=-6, agg_factor=1, - nn_threshold=0.5, cloud_threshold=0.99, - features=('ghi', 'dni', 'dhi'), - temporal_id=None): + def run_temporal_chunk( + cls, + fp_pattern, + nsrdb_fp, + fp_out_suffix='irradiance', + tz=-6, + agg_factor=1, + nn_threshold=0.5, + cloud_threshold=0.99, + features=('ghi', 'dni', 'dhi'), + temporal_id=None, + ): """Run the solar module on all spatial chunks for a single temporal chunk corresponding to the fp_pattern. 
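
# A runnable sketch of the chunk-id parsing in get_sup3r_fps() above: sup3r
# forward-pass output files end in ..._{temporal_id}_{spatial_id}.h5, so the
# last two underscore-separated tokens identify the chunk. The filenames below
# are made up for illustration.
fps = ['sup3r_chunk_000000_000000.h5',
       'sup3r_chunk_000000_000001.h5',
       'sup3r_chunk_000001_000000.h5']

spatial_ids = sorted({fp.replace('.h5', '').split('_')[-1] for fp in fps})
temporal_ids = sorted({fp.replace('.h5', '').split('_')[-2] for fp in fps})
print(temporal_ids, spatial_ids)  # ['000000', '000001'] ['000000', '000001']
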
This typically gets run from the CLI. @@ -609,22 +656,36 @@ def run_temporal_chunk(cls, fp_pattern, nsrdb_fp, fp_sets, t_slices, temporal_ids, _, target_fps = temp if temporal_id is not None: - fp_sets = [fp_set for i, fp_set in enumerate(fp_sets) - if temporal_ids[i] == temporal_id] - t_slices = [t_slice for i, t_slice in enumerate(t_slices) - if temporal_ids[i] == temporal_id] - target_fps = [target_fp for i, target_fp in enumerate(target_fps) - if temporal_ids[i] == temporal_id] + fp_sets = [ + fp_set + for i, fp_set in enumerate(fp_sets) + if temporal_ids[i] == temporal_id + ] + t_slices = [ + t_slice + for i, t_slice in enumerate(t_slices) + if temporal_ids[i] == temporal_id + ] + target_fps = [ + target_fp + for i, target_fp in enumerate(target_fps) + if temporal_ids[i] == temporal_id + ] zip_iter = zip(fp_sets, t_slices, target_fps) for i, (fp_set, t_slice, fp_target) in enumerate(zip_iter): fp_out = fp_target.replace('.h5', f'_{fp_out_suffix}.h5') - logger.info('Running temporal index {} out of {}.' - .format(i + 1, len(fp_sets))) - kwargs = dict(t_slice=t_slice, - tz=tz, - agg_factor=agg_factor, - nn_threshold=nn_threshold, - cloud_threshold=cloud_threshold) + logger.info( + 'Running temporal index {} out of {}.'.format( + i + 1, len(fp_sets) + ) + ) + kwargs = dict( + t_slice=t_slice, + tz=tz, + agg_factor=agg_factor, + nn_threshold=nn_threshold, + cloud_threshold=cloud_threshold, + ) with Solar(fp_set, nsrdb_fp, **kwargs) as solar: solar.write(fp_out, features=features) diff --git a/sup3r/utilities/cli.py b/sup3r/utilities/cli.py index a40cc5b3a..177d778e3 100644 --- a/sup3r/utilities/cli.py +++ b/sup3r/utilities/cli.py @@ -1,24 +1,20 @@ - # -*- coding: utf-8 -*- """ Sup3r base CLI class. """ -import click +import json import logging import os -import json - -from reV.pipeline.status import Status +import click from rex.utilities.execution import SubprocessManager from rex.utilities.hpc import SLURM from rex.utilities.loggers import init_mult -from sup3r.version import __version__ +from sup3r.pipeline import Status from sup3r.pipeline.config import BaseConfig from sup3r.utilities import ModuleName - logger = logging.getLogger(__name__) @@ -44,8 +40,9 @@ def from_config(cls, module_name, module_class, ctx, config_file, verbose): verbose : bool Whether to run in verbose mode. 
""" - config = cls.from_config_preflight(module_name, ctx, config_file, - verbose) + config = cls.from_config_preflight( + module_name, ctx, config_file, verbose + ) exec_kwargs = config.get('execution_control', {}) hardware_option = exec_kwargs.pop('option', 'local') @@ -94,7 +91,7 @@ def from_config_preflight(cls, module_name, ctx, config_file, verbose): log_file = config.get('log_file', None) log_pattern = config.get('log_pattern', None) config_verbose = config.get('log_level', 'INFO') - config_verbose = (config_verbose == 'DEBUG') + config_verbose = config_verbose == 'DEBUG' verbose = any([verbose, config_verbose, ctx.obj['VERBOSE']]) exec_kwargs = config.get('execution_control', {}) hardware_option = exec_kwargs.get('option', 'local') @@ -102,8 +99,12 @@ def from_config_preflight(cls, module_name, ctx, config_file, verbose): log_dir = log_file or log_pattern log_dir = log_dir if log_dir is None else os.path.dirname(log_dir) - init_mult(f'sup3r_{module_name.replace("-", "_")}', - log_dir, modules=[__name__, 'sup3r'], verbose=verbose) + init_mult( + f'sup3r_{module_name.replace("-", "_")}', + log_dir, + modules=[__name__, 'sup3r'], + verbose=verbose, + ) if log_pattern is not None: os.makedirs(os.path.dirname(log_pattern), exist_ok=True) @@ -130,14 +131,24 @@ def from_config_preflight(cls, module_name, ctx, config_file, verbose): @classmethod def check_module_name(cls, module_name): """Make sure module_name is a valid member of the ModuleName class""" - msg = ('Module name must be in ModuleName class. Received ' - f'{module_name}.') + msg = ( + 'Module name must be in ModuleName class. Received ' + f'{module_name}.' + ) assert module_name in ModuleName, msg @classmethod - def kickoff_slurm_job(cls, module_name, ctx, cmd, alloc='sup3r', - memory=None, walltime=4, feature=None, - stdout_path='./stdout/'): + def kickoff_slurm_job( + cls, + module_name, + ctx, + cmd, + alloc='sup3r', + memory=None, + walltime=4, + feature=None, + stdout_path='./stdout/', + ): """Run sup3r module on HPC via SLURM job submission. Parameters @@ -170,37 +181,53 @@ def kickoff_slurm_job(cls, module_name, ctx, cmd, alloc='sup3r', slurm_manager = SLURM() ctx.obj['SLURM_MANAGER'] = slurm_manager - status = Status.retrieve_job_status(out_dir, - module=module_name, - job_name=name, - hardware='slurm', - subprocess_manager=slurm_manager) + status = Status.retrieve_job_status( + out_dir, + module=module_name, + job_name=name, + hardware='slurm', + subprocess_manager=slurm_manager, + ) msg = f'sup3r {module_name} CLI failed to submit jobs!' if status == 'successful': - msg = (f'Job "{name}" is successful in status json found in ' - f'"{out_dir}", not re-running.') + msg = ( + f'Job "{name}" is successful in status json found in ' + f'"{out_dir}", not re-running.' + ) elif 'fail' not in str(status).lower() and status is not None: - msg = (f'Job "{name}" was found with status "{status}", not ' - 'resubmitting') + msg = ( + f'Job "{name}" was found with status "{status}", not ' + 'resubmitting' + ) else: - logger.info(f'Running sup3r {module_name} on SLURM with node ' - f'name "{name}".') - out = slurm_manager.sbatch(cmd, - alloc=alloc, - memory=memory, - walltime=walltime, - feature=feature, - name=name, - stdout_path=stdout_path)[0] + logger.info( + f'Running sup3r {module_name} on SLURM with node ' + f'name "{name}".' 
+ ) + out = slurm_manager.sbatch( + cmd, + alloc=alloc, + memory=memory, + walltime=walltime, + feature=feature, + name=name, + stdout_path=stdout_path, + )[0] if out: - msg = (f'Kicked off sup3r {module_name} job "{name}" ' - f'(SLURM jobid #{out}).') + msg = ( + f'Kicked off sup3r {module_name} job "{name}" ' + f'(SLURM jobid #{out}).' + ) # add job to sup3r status file. - Status.add_job(out_dir, module=module_name, - job_name=name, replace=True, - job_attrs={'job_id': out, 'hardware': 'slurm'}) + Status.add_job( + out_dir, + module=module_name, + job_name=name, + replace=True, + job_attrs={'job_id': out, 'hardware': 'slurm'}, + ) click.echo(msg) logger.info(msg) @@ -224,23 +251,30 @@ def kickoff_local_job(cls, module_name, ctx, cmd): name = ctx.obj['NAME'] out_dir = ctx.obj['OUT_DIR'] subprocess_manager = SubprocessManager - status = Status.retrieve_job_status(out_dir, - module=module_name, - job_name=name) + status = Status.retrieve_job_status( + out_dir, module=module_name, job_name=name + ) msg = f'sup3r {module_name} CLI failed to submit jobs!' if status == 'successful': - msg = (f'Job "{name}" is successful in status json found in ' - f'"{out_dir}", not re-running.') + msg = ( + f'Job "{name}" is successful in status json found in ' + f'"{out_dir}", not re-running.' + ) elif 'fail' not in str(status).lower() and status is not None: - msg = (f'Job "{name}" was found with status "{status}", not ' - 'resubmitting') + msg = ( + f'Job "{name}" was found with status "{status}", not ' + 'resubmitting' + ) else: - logger.info(f'Running sup3r {module_name} locally with job ' - f'name "{name}".') - Status.add_job(out_dir, module=module_name, job_name=name, - replace=True) + logger.info( + f'Running sup3r {module_name} locally with job ' + f'name "{name}".' + ) + Status.add_job( + out_dir, module=module_name, job_name=name, replace=True + ) subprocess_manager.submit(cmd) - msg = (f'Completed sup3r {module_name} job "{name}".') + msg = f'Completed sup3r {module_name} job "{name}".' click.echo(msg) logger.info(msg) @@ -273,10 +307,12 @@ def add_status_cmd(cls, config, module_name, cmd): status_file_arg_str += f'job_name="{job_name}", ' status_file_arg_str += 'attrs=job_attrs' - cmd += ('job_attrs = {};\n'.format(json.dumps(config) - .replace("null", "None") - .replace("false", "False") - .replace("true", "True"))) + cmd += 'job_attrs = {};\n'.format( + json.dumps(config) + .replace("null", "None") + .replace("false", "False") + .replace("true", "True") + ) cmd += 'job_attrs.update({"job_status": "successful"});\n' cmd += 'job_attrs.update({"time": t_elap});\n' cmd += f"Status.make_job_file({status_file_arg_str})" diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index d2f067fcc..de0080ebf 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -39,14 +39,30 @@ class EraDownloader: """Class to handle ERA5 downloading, variable renaming, file combination, and interpolation.""" - msg = ( - 'To download ERA5 data you need to have a ~/.cdsapirc file ' - 'with a valid url and api key. Follow the instructions here: ' - 'https://cds.climate.copernicus.eu/api-how-to' - ) + msg = ('To download ERA5 data you need to have a ~/.cdsapirc file ' + 'with a valid url and api key. 
Follow the instructions here: ' + 'https://cds.climate.copernicus.eu/api-how-to') req_file = os.path.join(os.path.expanduser('~'), '.cdsapirc') assert os.path.exists(req_file), msg + VALID_VARIABLES: ClassVar[list] = [ + 'u', + 'v', + 'pressure', + 'temperature', + 'relative_humidity', + 'specific_humidity', + 'total_precipitation', + ] + + KEEP_VARIABLES: ClassVar[list] = [ + 'orog', + 'time', + 'latitude', + 'longitude', + ] + KEEP_VARIABLES += [f'{v}_' for v in VALID_VARIABLES] + DEFAULT_RENAMED_VARS: ClassVar[list] = [ 'zg', 'orog', @@ -56,6 +72,8 @@ class EraDownloader: 'v_10m', 'u_100m', 'v_100m', + 'temperature', + 'pressure', ] DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ '10m_u_component_of_wind', @@ -64,6 +82,11 @@ class EraDownloader: '100m_v_component_of_wind', 'u_component_of_wind', 'v_component_of_wind', + '2m_temperature', + 'temperature', + 'surface_pressure', + 'relative_humidity', + 'total_precipitation', ] SFC_VARS: ClassVar[list] = [ @@ -74,12 +97,15 @@ class EraDownloader: 'surface_pressure', '2m_temperature', 'geopotential', + 'total_precipitation', ] LEVEL_VARS: ClassVar[list] = [ 'u_component_of_wind', 'v_component_of_wind', 'geopotential', 'temperature', + 'relative_humidity', + 'specific_humidity', ] NAME_MAP: ClassVar[dict] = { 'u10': 'u_10m', @@ -91,21 +117,22 @@ class EraDownloader: 'u': 'u', 'v': 'v', 'sp': 'pressure_0m', + 'r': 'relative_humidity', + 'q': 'specific_humidity', + 'tp': 'total_precip', } - def __init__( - self, - year, - month, - area, - levels, - combined_out_pattern, - interp_out_pattern=None, - run_interp=True, - overwrite=False, - required_shape=None, - variables=None, - ): + def __init__(self, + year, + month, + area, + levels, + combined_out_pattern, + interp_out_pattern=None, + run_interp=True, + overwrite=False, + required_shape=None, + variables=None): """Initialize the class. Parameters @@ -144,53 +171,73 @@ def __init__( self.run_interp = run_interp self.overwrite = overwrite self.combined_out_pattern = combined_out_pattern - self.variables = ( - variables if variables is not None else self.DEFAULT_DOWNLOAD_VARS - ) - self.days = [ - str(n).zfill(2) - for n in np.arange(1, monthrange(year, month)[1] + 1) - ] + self.interp_out_pattern = interp_out_pattern + self._interp_file = None + self._combined_file = None + self._variables = variables self.hours = [str(n).zfill(2) + ":00" for n in range(0, 24)] + self.sfc_file_variables = ['geopotential'] + self.level_file_variables = ['geopotential'] - if required_shape is None or len(required_shape) == 3: - self.required_shape = required_shape - elif len(required_shape) == 2 and len(levels) != required_shape[0]: - self.required_shape = (len(levels), *required_shape) - else: - msg = f'Received weird required_shape: {required_shape}.' 
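
# A small sketch of what the NAME_MAP above is for: translating ERA5's short
# variable names into the names the rest of sup3r expects. The entries below
# are copied from the mapping in this diff; everything else is illustrative.
name_map = {'u10': 'u_10m',
            'sp': 'pressure_0m',
            'r': 'relative_humidity',
            'q': 'specific_humidity',
            'tp': 'total_precip'}

raw_vars = ['u10', 'sp', 'q', 'latitude']
renamed = [name_map.get(v, v) for v in raw_vars]
print(renamed)  # ['u_10m', 'pressure_0m', 'specific_humidity', 'latitude']
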
- logger.error(msg) - raise OSError(msg) + self.shape_check(required_shape, levels) + self.check_good_vars(self.variables) + self.prep_var_lists(self.variables) - self.interp_file = None - if interp_out_pattern is not None and run_interp: - self.interp_file = interp_out_pattern.format( - year=year, month=str(month).zfill(2) - ) - os.makedirs(os.path.dirname(self.interp_file), exist_ok=True) + msg = ('Initialized EraDownloader with: ' + f'year={self.year}, month={self.month}, area={self.area}, ' + f'levels={self.levels}, variables={self.variables}') + logger.info(msg) - self.combined_file = combined_out_pattern.format( - year=year, month=str(month).zfill(2) - ) - os.makedirs(os.path.dirname(self.combined_file), exist_ok=True) + @property + def variables(self): + """Get list of requested variables""" + if self._variables is None: + self._variables = self.VALID_VARIABLES + return self._variables + + @property + def days(self): + """Get list of days for the requested month""" + return [ + str(n).zfill(2) + for n in np.arange(1, + monthrange(self.year, self.month)[1] + 1) + ] + + @property + def interp_file(self): + """Get name of file with interpolated variables""" + if self._interp_file is None: + if self.interp_out_pattern is not None and self.run_interp: + self._interp_file = self.interp_out_pattern.format( + year=self.year, month=str(self.month).zfill(2)) + os.makedirs(os.path.dirname(self._interp_file), exist_ok=True) + return self._interp_file + + @property + def combined_file(self): + """Get name of file from combined surface and level files""" + if self._combined_file is None: + self._combined_file = self.combined_out_pattern.format( + year=self.year, month=str(self.month).zfill(2)) + os.makedirs(os.path.dirname(self._combined_file), exist_ok=True) + return self._combined_file + + @property + def surface_file(self): + """Get name of file with variables from single level download""" basedir = os.path.dirname(self.combined_file) - self.surface_file = os.path.join( - basedir, f'sfc_{year}_{str(month).zfill(2)}.nc' - ) - self.level_file = os.path.join( - basedir, f'levels_{year}_{str(month).zfill(2)}.nc' - ) - self.sfc_file_variables = [] - self.level_file_variables = [] - self.check_good_vars(self.variables) - self.prep_var_lists(variables) + basename = f'sfc_{"_".join(self.variables)}_{self.year}_' + basename += f'{str(self.month).zfill(2)}.nc' + return os.path.join(basedir, basename) - msg = ( - 'Initialized EraDownloader with: ' - f'year={self.year}, month={self.month}, area={self.area}, ' - f'levels={self.levels}, variables={self.variables}' - ) - logger.info(msg) + @property + def level_file(self): + """Get name of file with variables from pressure level download""" + basedir = os.path.dirname(self.combined_file) + basename = f'levels_{"_".join(self.variables)}_{self.year}_' + basename += f'{str(self.month).zfill(2)}.nc' + return os.path.join(basedir, basename) @classmethod def init_dims(cls, old_ds, new_ds, dims): @@ -224,22 +271,29 @@ def get_tmp_file(cls, file): tmp_file = file.replace(".nc", "_tmp.nc") return tmp_file + def shape_check(self, required_shape, levels): + """Check given required shape""" + if required_shape is None or len(required_shape) == 3: + self.required_shape = required_shape + elif len(required_shape) == 2 and len(levels) != required_shape[0]: + self.required_shape = (len(levels), *required_shape) + else: + msg = f'Received weird required_shape: {required_shape}.' 
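
# A standalone sketch of the shape_check() logic added above: a 3D
# required_shape passes through unchanged, a 2D spatial shape gets the number
# of pressure levels prepended, and anything else is rejected.
def normalize_required_shape(required_shape, levels):
    if required_shape is None or len(required_shape) == 3:
        return required_shape
    if len(required_shape) == 2 and len(levels) != required_shape[0]:
        return (len(levels), *required_shape)
    raise OSError(f'Received weird required_shape: {required_shape}.')

print(normalize_required_shape((100, 200), levels=[1000, 900, 800]))
# (3, 100, 200)
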
+ logger.error(msg) + raise OSError(msg) + def check_good_vars(self, variables): """Make sure requested variables are valid. Parameters ---------- variables : list - List of variables to download. Can be any of ['u', 'v', 'pressure', - temperature'] + List of variables to download. Can be any of VALID_VARIABLES """ - valid_vars = ['u', 'v', 'pressure', 'temperature'] - good = all(var in valid_vars for var in variables) + good = all(var in self.VALID_VARIABLES for var in variables) if not good: - msg = ( - f'Received variables {variables} not in valid variables ' - f'list {valid_vars}' - ) + msg = (f'Received variables {variables} not in valid variables ' + f'list {self.VALID_VARIABLES}') logger.error(msg) raise OSError(msg) @@ -262,9 +316,10 @@ def prep_var_lists(self, variables): variables.""" variables = self._prep_var_lists(variables) for var in variables: - if var in self.SFC_VARS: + if var in self.SFC_VARS and var not in self.sfc_file_variables: self.sfc_file_variables.append(var) - elif var in self.LEVEL_VARS: + elif (var in self.LEVEL_VARS + and var not in self.level_file_variables): self.level_file_variables.append(var) else: msg = f'Requested {var} is not available for download.' @@ -274,14 +329,11 @@ def prep_var_lists(self, variables): def download_process_combine(self): """Run the download routine.""" sfc_check = len(self.sfc_file_variables) > 0 - level_check = ( - len(self.level_file_variables) > 0 and self.levels is not None - ) + level_check = (len(self.level_file_variables) > 0 + and self.levels is not None) if self.level_file_variables: - msg = ( - f'{self.level_file_variables} requested but no levels' - ' were provided.' - ) + msg = (f'{self.level_file_variables} requested but no levels' + ' were provided.') if self.levels is None: logger.warning(msg) warn(msg) @@ -289,22 +341,17 @@ def download_process_combine(self): self.download_surface_file() if level_check: self.download_levels_file() - if sfc_check and level_check: + if sfc_check or level_check: self.process_and_combine() def download_levels_file(self): """Download file with requested pressure levels""" if not os.path.exists(self.level_file) or self.overwrite: - if 'geopotential' not in self.level_file_variables: - self.level_file_variables.append('geopotential') - msg = ( - f'Downloading {self.level_file_variables} to ' - f'{self.level_file}.' - ) + msg = (f'Downloading {self.level_file_variables} to ' + f'{self.level_file}.') logger.info(msg) CDS_API_CLIENT.retrieve( - 'reanalysis-era5-pressure-levels', - { + 'reanalysis-era5-pressure-levels', { 'product_type': 'reanalysis', 'format': 'netcdf', 'variable': self.level_file_variables, @@ -314,25 +361,18 @@ def download_levels_file(self): 'day': self.days, 'time': self.hours, 'area': self.area, - }, - self.level_file, - ) + }, self.level_file) else: logger.info(f'File already exists: {self.level_file}.') def download_surface_file(self): """Download surface file""" if not os.path.exists(self.surface_file) or self.overwrite: - if 'geopotential' not in self.sfc_file_variables: - self.sfc_file_variables.append('geopotential') - msg = ( - f'Downloading {self.sfc_file_variables} to ' - f'{self.surface_file}.' 
- ) + msg = (f'Downloading {self.sfc_file_variables} to ' + f'{self.surface_file}.') logger.info(msg) CDS_API_CLIENT.retrieve( - 'reanalysis-era5-single-levels', - { + 'reanalysis-era5-single-levels', { 'product_type': 'reanalysis', 'format': 'netcdf', 'variable': self.sfc_file_variables, @@ -341,9 +381,7 @@ def download_surface_file(self): 'day': self.days, 'time': self.hours, 'area': self.area, - }, - self.surface_file, - ) + }, self.surface_file) else: logger.info(f'File already exists: {self.surface_file}.') @@ -360,10 +398,8 @@ def process_surface_file(self): ds = self.map_vars(old_ds, ds) os.system(f'mv {tmp_file} {self.surface_file}') - logger.info( - f'Finished processing {self.surface_file}. Moved ' - f'{tmp_file} to {self.surface_file}.' - ) + logger.info(f'Finished processing {self.surface_file}. Moved ' + f'{tmp_file} to {self.surface_file}.') def map_vars(self, old_ds, ds): """Map variables from old dataset to new dataset @@ -414,9 +450,9 @@ def convert_z(self, standard_name, long_name, old_ds, ds): Dataset() object for new file with new height variable written. """ - _ = ds.createVariable( - standard_name, np.float32, dimensions=old_ds['z'].dimensions - ) + _ = ds.createVariable(standard_name, + np.float32, + dimensions=old_ds['z'].dimensions) ds.variables[standard_name][:] = old_ds['z'][:] / 9.81 ds.variables[standard_name].long_name = long_name ds.variables[standard_name].standard_name = 'zg' @@ -440,36 +476,41 @@ def process_level_file(self): tmp = np.zeros(ds.variables['zg'].shape) for i in range(tmp.shape[1]): tmp[:, i, :, :] = ds.variables['level'][i] * 100 - _ = ds.createVariable( - 'pressure', np.float32, dimensions=dims - ) + _ = ds.createVariable('pressure', + np.float32, + dimensions=dims) ds.variables['pressure'][:] = tmp[...] ds.variables['pressure'].long_name = 'Pressure' ds.variables['pressure'].units = 'Pa' os.system(f'mv {tmp_file} {self.level_file}') - logger.info( - f'Finished processing {self.level_file}. Moved ' - f'{tmp_file} to {self.level_file}.' - ) + logger.info(f'Finished processing {self.level_file}. Moved ' + f'{tmp_file} to {self.level_file}.') def process_and_combine(self): """Process variables and combine.""" if not os.path.exists(self.combined_file) or self.overwrite: - logger.info(f'Processing {self.level_file}.') - self.process_level_file() - logger.info(f'Processing {self.surface_file}.') - self.process_surface_file() - logger.info( - f'Combining {self.level_file} and {self.surface_file} ' - f'to {self.combined_file}.' - ) - with xr.open_mfdataset([self.level_file, self.surface_file]) as ds: + files = [] + if os.path.exists(self.level_file): + logger.info(f'Processing {self.level_file}.') + self.process_level_file() + files.append(self.level_file) + if os.path.exists(self.surface_file): + logger.info(f'Processing {self.surface_file}.') + self.process_surface_file() + files.append(self.surface_file) + + logger.info(f'Combining {files} and {self.surface_file} ' + f'to {self.combined_file}.') + with xr.open_mfdataset(files) as ds: ds.to_netcdf(self.combined_file) logger.info(f'Finished writing {self.combined_file}') - os.remove(self.level_file) - os.remove(self.surface_file) + + if os.path.exists(self.level_file): + os.remove(self.level_file) + if os.path.exists(self.surface_file): + os.remove(self.surface_file) def good_file(self, file, required_shape): """Check if file has the required shape and variables. @@ -486,12 +527,10 @@ def good_file(self, file, required_shape): bool Whether or not data has required shape and variables. 
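
# The unit handling in convert_z() and process_level_file() above comes down
# to two conversions: geopotential (m2/s2) divided by g gives geopotential
# height in meters, and ERA5 pressure levels in hPa times 100 give Pa. A
# minimal numpy sketch:
import numpy as np

z = np.array([98100.0, 49050.0])     # geopotential, m2/s2
zg = z / 9.81                        # [10000., 5000.] m

levels_hpa = np.array([1000, 900, 800])
pressure_pa = levels_hpa * 100       # [100000, 90000, 80000] Pa
print(zg, pressure_pa)
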
""" - out = self.check_single_file( - file, - check_nans=False, - check_heights=False, - required_shape=required_shape, - ) + out = self.check_single_file(file, + check_nans=False, + check_heights=False, + required_shape=required_shape) good_vars, good_shape, _, _ = out check = good_vars and good_shape return check @@ -512,31 +551,26 @@ def check_existing_files(self): os.remove(self.level_file) if os.path.exists(self.surface_file): os.remove(self.surface_file) - logger.info( - f'{self.combined_file} already exists and ' - f'overwrite={self.overwrite}. Skipping.' - ) + logger.info(f'{self.combined_file} already exists and ' + f'overwrite={self.overwrite}. Skipping.') except Exception as e: logger.info(f'Something wrong with {self.combined_file}. {e}') if os.path.exists(self.combined_file): os.remove(self.combined_file) check = self.interp_file is not None and os.path.exists( - self.interp_file - ) + self.interp_file) if check: os.remove(self.interp_file) def run_interpolation(self, max_workers=None, **kwargs): """Run interpolation to get final final. Runs log interpolation up to max_log_height (usually 100m) and linear interpolation above this.""" - LogLinInterpolator.run( - infile=self.combined_file, - outfile=self.interp_file, - max_workers=max_workers, - variables=self.variables, - overwrite=self.overwrite, - **kwargs, - ) + LogLinInterpolator.run(infile=self.combined_file, + outfile=self.interp_file, + max_workers=max_workers, + variables=self.variables, + overwrite=self.overwrite, + **kwargs) def get_monthly_file(self, interp_workers=None, **interp_kwargs): """Download level and surface files, process variables, and combine @@ -579,29 +613,17 @@ def all_months_exist(cls, year, file_pattern): """ return all( os.path.exists( - file_pattern.format(year=year, month=str(month).zfill(2)) - ) - for month in range(1, 13) - ) + file_pattern.format(year=year, month=str(month).zfill(2))) + for month in range(1, 13)) @classmethod def already_pruned(cls, infile): """Check if file has been pruned already.""" - keep_vars = ( - 'u_', - 'v_', - 'pressure_', - 'temperature_', - 'orog', - 'time', - 'latitude', - 'longitude', - ) pruned = True with Dataset(infile, 'r') as ds: for var in ds.variables: - if not any(name in var for name in keep_vars): + if not any(name in var for name in cls.KEEP_VARIABLES): logger.info(f'Pruning {var} in {infile}.') pruned = False return pruned @@ -612,19 +634,16 @@ def prune_output(cls, infile): logger.info(f'Pruning {infile}.') tmp_file = cls.get_tmp_file(infile) - keep_vars = ('u_', 'v_', 'pressure_', 'temperature_', 'orog') with Dataset(infile, 'r') as old_ds: with Dataset(tmp_file, 'w') as new_ds: - new_ds = cls.init_dims( - old_ds, new_ds, ('time', 'latitude', 'longitude') - ) + new_ds = cls.init_dims(old_ds, new_ds, + ('time', 'latitude', 'longitude')) for var in old_ds.variables: - if any(name in var for name in keep_vars): + if any(name in var for name in cls.KEEP_VARIABLES): old_var = old_ds[var] vals = old_var[:] _ = new_ds.createVariable( - var, old_var.dtype, dimensions=old_var.dimensions - ) + var, old_var.dtype, dimensions=old_var.dimensions) new_ds[var][:] = vals if hasattr(old_var, 'units'): new_ds[var].units = old_var.units @@ -634,27 +653,23 @@ def prune_output(cls, infile): if hasattr(old_var, 'long_name'): new_ds[var].long_name = old_var.long_name os.system(f'mv {tmp_file} {infile}') - logger.info( - f'Finished pruning variables in {infile}. Moved ' - f'{tmp_file} to {infile}.' - ) + logger.info(f'Finished pruning variables in {infile}. 
Moved ' + f'{tmp_file} to {infile}.') @classmethod - def run_month( - cls, - year, - month, - area, - levels, - combined_out_pattern, - interp_out_pattern=None, - run_interp=True, - overwrite=False, - required_shape=None, - interp_workers=None, - variables=None, - **interp_kwargs, - ): + def run_month(cls, + year, + month, + area, + levels, + combined_out_pattern, + interp_out_pattern=None, + run_interp=True, + overwrite=False, + required_shape=None, + interp_workers=None, + variables=None, + **interp_kwargs): """Run routine for all months in the requested year. Parameters @@ -690,40 +705,35 @@ def run_month( **interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ - downloader = cls( - year=year, - month=month, - area=area, - levels=levels, - combined_out_pattern=combined_out_pattern, - interp_out_pattern=interp_out_pattern, - run_interp=run_interp, - overwrite=overwrite, - required_shape=required_shape, - variables=variables, - ) - downloader.get_monthly_file( - interp_workers=interp_workers, **interp_kwargs - ) + downloader = cls(year=year, + month=month, + area=area, + levels=levels, + combined_out_pattern=combined_out_pattern, + interp_out_pattern=interp_out_pattern, + run_interp=run_interp, + overwrite=overwrite, + required_shape=required_shape, + variables=variables) + downloader.get_monthly_file(interp_workers=interp_workers, + **interp_kwargs) @classmethod - def run_year( - cls, - year, - area, - levels, - combined_out_pattern, - combined_yearly_file, - interp_out_pattern=None, - interp_yearly_file=None, - run_interp=True, - overwrite=False, - required_shape=None, - max_workers=None, - interp_workers=None, - variables=None, - **interp_kwargs, - ): + def run_year(cls, + year, + area, + levels, + combined_out_pattern, + combined_yearly_file, + interp_out_pattern=None, + interp_yearly_file=None, + run_interp=True, + overwrite=False, + required_shape=None, + max_workers=None, + interp_workers=None, + variables=None, + **interp_kwargs): """Run routine for all months in the requested year. Parameters @@ -766,20 +776,18 @@ def run_year( """ if max_workers == 1: for month in range(1, 13): - cls.run_month( - year=year, - month=month, - area=area, - levels=levels, - combined_out_pattern=combined_out_pattern, - interp_out_pattern=interp_out_pattern, - run_interp=run_interp, - overwrite=overwrite, - required_shape=required_shape, - interp_workers=interp_workers, - variables=variables, - **interp_kwargs, - ) + cls.run_month(year=year, + month=month, + area=area, + levels=levels, + combined_out_pattern=combined_out_pattern, + interp_out_pattern=interp_out_pattern, + run_interp=run_interp, + overwrite=overwrite, + required_shape=required_shape, + interp_workers=interp_workers, + variables=variables, + **interp_kwargs) else: futures = {} with ThreadPoolExecutor(max_workers=max_workers) as exe: @@ -797,20 +805,15 @@ def run_year( required_shape=required_shape, interp_workers=interp_workers, variables=variables, - **interp_kwargs, - ) + **interp_kwargs) futures[future] = {'year': year, 'month': month} - logger.info( - f'Submitted future for year {year} and month ' - f'{month}.' - ) + logger.info(f'Submitted future for year {year} and month ' + f'{month}.') for future in as_completed(futures): future.result() v = futures[future] - logger.info( - f'Finished future for year {v["year"]} and month ' - f'{v["month"]}.' 
- ) + logger.info(f'Finished future for year {v["year"]} and month ' + f'{v["month"]}.') cls.make_yearly_file(year, combined_out_pattern, combined_yearly_file) @@ -831,10 +834,8 @@ def make_yearly_file(cls, year, file_pattern, yearly_file): yearly_file : str Name of yearly file made from monthly files. """ - msg = ( - f'Not all monthly files with file_patten {file_pattern} for ' - f'year {year} exist.' - ) + msg = (f'Not all monthly files with file_patten {file_pattern} for ' + f'year {year} exist.') assert cls.all_months_exist(year, file_pattern), msg files = [ @@ -852,16 +853,14 @@ def make_yearly_file(cls, year, file_pattern, yearly_file): logger.info(f'{yearly_file} already exists.') @classmethod - def _check_single_file( - cls, - res, - var_list=None, - check_nans=True, - check_heights=True, - max_interp_height=200, - required_shape=None, - max_workers=10, - ): + def _check_single_file(cls, + res, + var_list=None, + check_nans=True, + check_heights=True, + max_interp_height=200, + required_shape=None, + max_workers=10): """Make sure given files include the given variables. Check for NaNs and required shape. @@ -902,21 +901,15 @@ def _check_single_file( *res['latitude'].shape, *res['longitude'].shape, ) - good_shape = ( - 'NA' if required_shape is None else (res_shape == required_shape) - ) - good_hgts = ( - 'NA' - if not check_heights - else cls.check_heights( - res, - max_interp_height=max_interp_height, - max_workers=max_workers, - ) - ) - nan_pct = ( - 'NA' if not check_nans else cls.get_nan_pct(res, var_list=var_list) - ) + good_shape = ('NA' if required_shape is None else + (res_shape == required_shape)) + good_hgts = ('NA' if not check_heights else cls.check_heights( + res, + max_interp_height=max_interp_height, + max_workers=max_workers, + )) + nan_pct = ('NA' if not check_nans else cls.get_nan_pct( + res, var_list=var_list)) if not good_vars: mask = np.array([var not in res for var in var_list]) @@ -948,23 +941,20 @@ def check_heights(cls, res, max_interp_height=200, max_workers=10): location and timestep """ gp = res['zg'].values - sfc_hgt = np.repeat( - res['orog'].values[:, np.newaxis, ...], gp.shape[1], axis=1 - ) + sfc_hgt = np.repeat(res['orog'].values[:, np.newaxis, ...], + gp.shape[1], + axis=1) heights = gp - sfc_hgt heights = heights.reshape(heights.shape[0], heights.shape[1], -1) checks = [] logger.info( - f'Checking heights with max_interp_height={max_interp_height}.' - ) + f'Checking heights with max_interp_height={max_interp_height}.') if max_workers == 1: for idt in range(heights.shape[0]): checks.append( cls._check_heights_single_ts( - heights[idt], max_interp_height=max_interp_height - ) - ) + heights[idt], max_interp_height=max_interp_height)) msg = f'Finished check for {idt + 1} of {heights.shape[0]}.' 
logger.debug(msg) else: @@ -977,17 +967,13 @@ def check_heights(cls, res, max_interp_height=200, max_workers=10): max_interp_height=max_interp_height, ) futures.append(future) - msg = ( - f'Submitted height check for {idt + 1} of ' - f'{heights.shape[0]}' - ) + msg = (f'Submitted height check for {idt + 1} of ' + f'{heights.shape[0]}') logger.info(msg) for i, future in enumerate(as_completed(futures)): checks.append(future.result()) - msg = ( - f'Finished height check for {i + 1} of ' - f'{heights.shape[0]}' - ) + msg = (f'Finished height check for {i + 1} of ' + f'{heights.shape[0]}') logger.info(msg) return all(checks) @@ -1044,16 +1030,14 @@ def get_nan_pct(cls, res, var_list=None): return 100 * nan_count / elem_count @classmethod - def check_single_file( - cls, - file, - var_list=None, - check_nans=True, - check_heights=True, - max_interp_height=200, - required_shape=None, - max_workers=10, - ): + def check_single_file(cls, + file, + var_list=None, + check_nans=True, + check_heights=True, + max_interp_height=200, + required_shape=None, + max_workers=10): """Make sure given files include the given variables. Check for NaNs and required shape. @@ -1093,9 +1077,7 @@ def check_single_file( good_shape = None good_vars = None good_hgts = None - var_list = ( - var_list if var_list is not None else cls.DEFAULT_RENAMED_VARS - ) + var_list = (var_list if var_list is not None else cls.VALID_VARIABLES) try: res = xr.open_dataset(file) except Exception as e: @@ -1105,30 +1087,26 @@ def check_single_file( good = False if good: - out = cls._check_single_file( - res, - var_list, - check_nans=check_nans, - check_heights=check_heights, - max_interp_height=max_interp_height, - required_shape=required_shape, - max_workers=max_workers, - ) + out = cls._check_single_file(res, + var_list, + check_nans=check_nans, + check_heights=check_heights, + max_interp_height=max_interp_height, + required_shape=required_shape, + max_workers=max_workers) good_vars, good_shape, good_hgts, nan_pct = out return good_vars, good_shape, good_hgts, nan_pct @classmethod - def run_files_checks( - cls, - file_pattern, - var_list=None, - required_shape=None, - check_nans=True, - check_heights=True, - max_interp_height=200, - max_workers=None, - height_check_workers=10, - ): + def run_files_checks(cls, + file_pattern, + var_list=None, + required_shape=None, + check_nans=True, + check_heights=True, + max_interp_height=200, + max_workers=None, + height_check_workers=10): """Make sure given files include the given variables. Check for NaNs and required shape. 
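The height check above derives height above ground as geopotential height ('zg') minus surface elevation ('orog') and then validates each timestep. Below is a minimal sketch of the per-timestep test, assuming `_check_heights_single_ts` (whose body is outside this hunk) simply verifies that every grid column reaches `max_interp_height`; the helper name and the toy arrays are illustrative, not part of the patch.

import numpy as np

def check_heights_single_ts(heights_t, max_interp_height=200):
    """Illustrative stand-in for _check_heights_single_ts (assumed behavior).

    heights_t has shape (n_levels, n_points) and holds height above ground
    (zg - orog) for one timestep. The check passes only if every column of
    model levels reaches max_interp_height, so interpolation up to that
    height never has to extrapolate.
    """
    return bool((np.nanmax(heights_t, axis=0) >= max_interp_height).all())

# Toy columns: 3 model levels over 4 grid points for a single timestep
heights_t = np.array([[10., 12., 11., 9.],        # lowest level, m AGL
                      [80., 85., 78., 90.],
                      [250., 240., 260., 255.]])  # highest level, m AGL
print(check_heights_single_ts(heights_t))      # True: all columns span 200 m
print(check_heights_single_ts(heights_t[:2]))  # False: columns top out below 200 m

check_heights collects one such boolean per timestep and returns all(checks), so a single shallow column anywhere in the domain fails the file.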
@@ -1165,9 +1143,9 @@ def run_files_checks( files = glob(file_pattern) else: files = file_pattern - df = pd.DataFrame( - columns=['file', 'good_vars', 'good_shape', 'good_hgts', 'nan_pct'] - ) + df = pd.DataFrame(columns=[ + 'file', 'good_vars', 'good_shape', 'good_hgts', 'nan_pct' + ]) df['file'] = [os.path.basename(file) for file in files] if max_workers == 1: for i, file in enumerate(files): @@ -1179,36 +1157,29 @@ def run_files_checks( check_heights=check_heights, max_interp_height=max_interp_height, max_workers=height_check_workers, - required_shape=required_shape, - ) - df.at[i, df.columns[1:]] = out + required_shape=required_shape) + df.loc[i, df.columns[1:]] = out logger.info(f'Finished checking {file}.') else: futures = {} with ThreadPoolExecutor(max_workers=max_workers) as exe: for i, file in enumerate(files): - future = exe.submit( - cls.check_single_file, - file=file, - var_list=var_list, - check_nans=check_nans, - check_heights=check_heights, - max_interp_height=max_interp_height, - max_workers=height_check_workers, - required_shape=required_shape, - ) - msg = ( - f'Submitted file check future for {file}. Future ' - f'{i + 1} of {len(files)}.' - ) + future = exe.submit(cls.check_single_file, + file=file, + var_list=var_list, + check_nans=check_nans, + check_heights=check_heights, + max_interp_height=max_interp_height, + max_workers=height_check_workers, + required_shape=required_shape) + msg = (f'Submitted file check future for {file}. Future ' + f'{i + 1} of {len(files)}.') logger.info(msg) futures[future] = i for i, future in enumerate(as_completed(futures)): out = future.result() - df.at[futures[future], df.columns[1:]] = out - msg = ( - f'Finished checking {df["file"].iloc[futures[future]]}.' - f' Future {i + 1} of {len(files)}.' - ) + df.loc[futures[future], df.columns[1:]] = out + msg = (f'Finished checking {df["file"].iloc[futures[future]]}.' + f' Future {i + 1} of {len(files)}.') logger.info(msg) return df diff --git a/sup3r/utilities/execution.py b/sup3r/utilities/execution.py index 96e3929b6..67b225605 100644 --- a/sup3r/utilities/execution.py +++ b/sup3r/utilities/execution.py @@ -4,10 +4,11 @@ @author: bbenton """ -import numpy as np import logging import os +import numpy as np + logger = logging.getLogger(__name__) @@ -15,8 +16,9 @@ class DistributedProcess: """High-level class with commonly used functionality for processes distributed across multiple nodes""" - def __init__(self, max_nodes=1, n_chunks=None, max_chunks=None, - incremental=False): + def __init__( + self, max_nodes=1, n_chunks=None, max_chunks=None, incremental=False + ): """ Parameters ---------- @@ -30,11 +32,14 @@ def __init__(self, max_nodes=1, n_chunks=None, max_chunks=None, incremental : bool Whether to skip previously run process chunks or to overwrite. """ - msg = ('For a distributed process either max_chunks or ' - 'max_chunks + n_chunks must be specified. Received ' - f'max_chunks={max_chunks}, n_chunks={n_chunks}.') + msg = ( + 'For a distributed process either max_chunks or ' + 'max_chunks + n_chunks must be specified. Received ' + f'max_chunks={max_chunks}, n_chunks={n_chunks}.' 
+ ) assert max_chunks is not None, msg self._node_chunks = None + self._node_files = None self._n_chunks = n_chunks self._max_nodes = max_nodes self._max_chunks = max_chunks @@ -59,8 +64,9 @@ def node_finished(self, node_index): bool Whether all processes for the given node have finished """ - return all(self.chunk_finished(i) - for i in self.node_chunks[node_index]) + return all( + self.chunk_finished(i) for i in self.node_chunks[node_index] + ) # pylint: disable=E1136 def chunk_finished(self, chunk_index): @@ -80,8 +86,10 @@ def chunk_finished(self, chunk_index): """ out_file = self.out_files[chunk_index] if os.path.exists(out_file) and self.incremental: - logger.info('Not running chunk index {}, output file ' - 'exists: {}'.format(chunk_index, out_file)) + logger.info( + 'Not running chunk index {}, output file ' + 'exists: {}'.format(chunk_index, out_file) + ) return True return False @@ -119,10 +127,19 @@ def node_chunks(self): """Get the chunk indices for different nodes""" if self._node_chunks is None: n_chunks = min(self.max_nodes, self.chunks) - self._node_chunks = np.array_split(np.arange(self.chunks), - n_chunks) + self._node_chunks = np.array_split( + np.arange(self.chunks), n_chunks + ) return self._node_chunks + @property + def node_files(self): + """Get the file lists for different nodes""" + if self._node_files is None: + n_chunks = min(self.max_nodes, self.chunks) + self._node_files = np.array_split(self.out_files, n_chunks) + return self._node_files + @property def failed_chunks(self): """Check whether any processes have failed.""" diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index bd70d9697..7a7f51910 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -8,6 +8,7 @@ as_completed, ) from glob import glob +from typing import ClassVar from warnings import warn import numpy as np @@ -22,7 +23,6 @@ init_logger(__name__, log_level='DEBUG') init_logger('sup3r', log_level='DEBUG') - logger = logging.getLogger(__name__) @@ -31,11 +31,12 @@ class LogLinInterpolator: max_log_height, linearly interpolate components above max_log_height meters, and save to file""" - DEFAULT_OUTPUT_HEIGHTS = { + DEFAULT_OUTPUT_HEIGHTS: ClassVar[dict] = { 'u': [40, 80, 120, 160, 200], 'v': [40, 80, 120, 160, 200], 'temperature': [10, 40, 80, 100, 120, 160, 200], 'pressure': [0, 100, 200], + 'relative_humidity': [80, 100, 120], } def __init__( @@ -68,10 +69,8 @@ def __init__( self.infile = infile self.outfile = outfile - msg = ( - 'output_heights must be a dictionary with variables as keys ' - f'and lists of heights as values. Received: {output_heights}.' - ) + msg = ('output_heights must be a dictionary with variables as keys ' + f'and lists of heights as values. Received: {output_heights}.') assert output_heights is None or isinstance(output_heights, dict), msg self.new_heights = output_heights or self.DEFAULT_OUTPUT_HEIGHTS @@ -83,11 +82,9 @@ def __init__( msg = f'{self.infile} does not exist. Skipping.' assert os.path.exists(self.infile), msg - msg = ( - f'Initializing {self.__class__.__name__} with infile={infile}, ' - f'outfile={outfile}, new_heights={self.new_heights}, ' - f'variables={variables}.' 
- ) + msg = (f'Initializing {self.__class__.__name__} with infile={infile}, ' + f'outfile={outfile}, new_heights={self.new_heights}, ' + f'variables={variables}.') logger.info(msg) def _load_single_var(self, variable): @@ -110,9 +107,9 @@ def _load_single_var(self, variable): logger.info(f'Loading {self.infile} for {variable}.') with xr.open_dataset(self.infile) as res: gp = res['zg'].values - sfc_hgt = np.repeat( - res['orog'].values[:, np.newaxis, ...], gp.shape[1], axis=1 - ) + sfc_hgt = np.repeat(res['orog'].values[:, np.newaxis, ...], + gp.shape[1], + axis=1) heights = gp - sfc_hgt input_heights = [] @@ -125,9 +122,9 @@ def _load_single_var(self, variable): height_arr = [] shape = (heights.shape[0], 1, *heights.shape[2:]) for height in input_heights: - var_arr.append( - res[f'{variable}_{height}m'].values[:, np.newaxis, ...] - ) + var_arr.append(res[f'{variable}_{height}m'].values[:, + np.newaxis, + ...]) height_arr.append(np.full(shape, height, dtype=np.float32)) if variable in res: @@ -162,8 +159,7 @@ def interpolate_vars(self, max_workers=None): if var not in ('u', 'v'): max_log_height = -np.inf logger.info( - f'Interpolating {var} to heights = {self.new_heights[var]}.' - ) + f'Interpolating {var} to heights = {self.new_heights[var]}.') self.new_data[var] = self.interp_var_to_height( var_array=arrs['data'], @@ -208,6 +204,38 @@ def save_output(self): ds.close() logger.info(f'Saved interpolated output to {self.outfile}.') + @classmethod + def init_dims(cls, old_ds, new_ds, dims): + """Initialize dimensions in new dataset from old dataset + + Parameters + ---------- + old_ds : Dataset + Dataset() object from old file + new_ds : Dataset + Dataset() object for new file + dims : tuple + Tuple of dimensions. e.g. ('time', 'latitude', 'longitude') + + Returns + ------- + new_ds : Dataset + Dataset() object for new file with dimensions initialized. + """ + for var in dims: + new_ds.createDimension(var, len(old_ds[var])) + _ = new_ds.createVariable(var, old_ds[var].dtype, dimensions=var) + new_ds[var][:] = old_ds[var][:] + new_ds[var].units = old_ds[var].units + return new_ds + + @classmethod + def get_tmp_file(cls, file): + """Get temp file for given file. Then only needed variables will be + written to the given file.""" + tmp_file = file.replace('.nc', '_tmp.nc') + return tmp_file + @classmethod def run( cls, @@ -250,8 +278,7 @@ def run( ) if os.path.exists(outfile) and not overwrite: logger.info( - f'{outfile} already exists and overwrite=False. ' 'Skipping.' - ) + f'{outfile} already exists and overwrite=False. 
Skipping.') else: log_interp.load() log_interp.interpolate_vars(max_workers=max_workers) @@ -296,8 +323,7 @@ def run_multiple( if max_workers == 1: for _, file in enumerate(infiles): outfile = os.path.basename(file).replace( - '.nc', '_all_interp.nc' - ) + '.nc', '_all_interp.nc') outfile = os.path.join(out_dir, outfile) cls.run( file, @@ -312,36 +338,29 @@ def run_multiple( with ThreadPoolExecutor(max_workers=max_workers) as exe: for i, file in enumerate(infiles): outfile = os.path.basename(file).replace( - '.nc', '_all_interp.nc' - ) + '.nc', '_all_interp.nc') outfile = os.path.join(out_dir, outfile) futures.append( - exe.submit( - cls.run, - file, - outfile, - output_heights=output_heights, - variables=variables, - max_log_height=max_log_height, - overwrite=overwrite, - ) - ) + exe.submit(cls.run, + file, + outfile, + output_heights=output_heights, + variables=variables, + max_log_height=max_log_height, + overwrite=overwrite)) logger.info( - f'{i + 1} of {len(infiles)} futures submitted.' - ) + f'{i + 1} of {len(infiles)} futures submitted.') for i, future in enumerate(as_completed(futures)): future.result() logger.info(f'{i + 1} of {len(futures)} futures complete.') @classmethod - def pbl_interp_to_height( - cls, - lev_array, - var_array, - levels, - fixed_level_mask=None, - max_log_height=100, - ): + def pbl_interp_to_height(cls, + lev_array, + var_array, + levels, + fixed_level_mask=None, + max_log_height=100): """Fit ws log law to data below max_log_height. Parameters @@ -386,19 +405,14 @@ def ws_log_profile(z, a, b): var_mask = (0 < lev_array_samp) & (lev_array_samp <= max_log_height) try: - popt, _ = curve_fit( - ws_log_profile, - lev_array_samp[var_mask], - var_array_samp[var_mask], - ) + popt, _ = curve_fit(ws_log_profile, lev_array_samp[var_mask], + var_array_samp[var_mask]) log_ws = ws_log_profile(levels[lev_mask], *popt) except Exception as e: - msg = ( - 'Log interp failed with (h, ws) = ' - f'({lev_array_samp[var_mask]}, ' - f'{var_array_samp[var_mask]}). {e} ' - 'Using linear interpolation.' - ) + msg = ('Log interp failed with (h, ws) = ' + f'({lev_array_samp[var_mask]}, ' + f'{var_array_samp[var_mask]}). {e} ' + 'Using linear interpolation.') good = False logger.warning(msg) warn(msg) @@ -410,14 +424,12 @@ def ws_log_profile(z, a, b): return log_ws, good @classmethod - def _interp_var_to_height( - cls, - lev_array, - var_array, - levels, - fixed_level_mask=None, - max_log_height=100, - ): + def _interp_var_to_height(cls, + lev_array, + var_array, + levels, + fixed_level_mask=None, + max_log_height=100): """Fit ws log law to wind data below max_log_height and linearly interpolate data above. Linearly interpolate non wind data. 
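pbl_interp_to_height fits a two-parameter log profile to the wind samples below max_log_height and evaluates the fit at the requested output heights, warning and falling back to linear interpolation if the fit fails. Below is a minimal sketch of that fit, assuming ws_log_profile has the usual neutral log-law form a * ln(z) + b (its body sits outside this hunk) and using made-up wind samples.

import numpy as np
from scipy.optimize import curve_fit

def ws_log_profile(z, a, b):
    """Assumed neutral log-law wind profile: ws = a * ln(z) + b."""
    return a * np.log(z) + b

heights = np.array([10., 40., 80., 100.])   # sample heights, m above ground
windspeed = np.array([4.1, 5.6, 6.4, 6.7])  # sample wind speeds, m/s
max_log_height = 100

# Fit only to samples in the log-law layer, mirroring var_mask above
mask = (heights > 0) & (heights <= max_log_height)
popt, _ = curve_fit(ws_log_profile, heights[mask], windspeed[mask])

new_levels = np.array([40., 80.])        # requested output heights below 100 m
print(ws_log_profile(new_levels, *popt))  # log-law estimates at the new heights

Heights above max_log_height are handled separately with interp1d(..., fill_value='extrapolate'), which is why _interp_var_to_height stitches together a log_ws result and a lin_ws result.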
@@ -454,43 +466,35 @@ def _interp_var_to_height( good = True hgt_check = any(levels < max_log_height) and any( - lev_array < max_log_height - ) + lev_array < max_log_height) if hgt_check: log_ws, good = cls.pbl_interp_to_height( lev_array, var_array, levels, fixed_level_mask=fixed_level_mask, - max_log_height=max_log_height, - ) + max_log_height=max_log_height) if any(levels > max_log_height): lev_mask = levels >= max_log_height var_mask = lev_array >= max_log_height if len(lev_array[var_mask]) > 1: - lin_ws = interp1d( - lev_array[var_mask], - var_array[var_mask], - fill_value='extrapolate', - )(levels[lev_mask]) + lin_ws = interp1d(lev_array[var_mask], + var_array[var_mask], + fill_value='extrapolate')(levels[lev_mask]) elif len(lev_array) > 1: - msg = ( - 'Requested interpolation levels are outside the ' - f'available range: lev_array={lev_array}, ' - f'levels={levels}. Using linear extrapolation.' - ) - lin_ws = interp1d( - lev_array, var_array, fill_value='extrapolate' - )(levels[lev_mask]) + msg = ('Requested interpolation levels are outside the ' + f'available range: lev_array={lev_array}, ' + f'levels={levels}. Using linear extrapolation.') + lin_ws = interp1d(lev_array, + var_array, + fill_value='extrapolate')(levels[lev_mask]) good = False logger.warning(msg) warn(msg) else: - msg = ( - 'Data seems to be all NaNs. Something may have gone ' - 'wrong during download.' - ) + msg = ('Data seems to be all NaNs. Something may have gone ' + 'wrong during download.') raise OSError(msg) if log_ws is not None and lin_ws is not None: @@ -503,10 +507,8 @@ def _interp_var_to_height( out = lin_ws if log_ws is None and lin_ws is None: - msg = ( - f'No interpolation was performed for lev_array={lev_array} ' - f'and levels={levels}' - ) + msg = (f'No interpolation was performed for lev_array={lev_array} ' + f'and levels={levels}') raise RuntimeError(msg) return out, good @@ -545,15 +547,13 @@ def _get_timestep_interp_input(cls, lev_array, var_array, idt): return h_t, var_t, mask @classmethod - def interp_single_ts( - cls, - hgt_t, - var_t, - mask, - levels, - fixed_level_mask=None, - max_log_height=100, - ): + def interp_single_ts(cls, + hgt_t, + var_t, + mask, + levels, + fixed_level_mask=None, + max_log_height=100): """Perform interpolation for a single timestep specified by the index idt @@ -599,15 +599,13 @@ def interp_single_ts( return np.array(out_array), np.array(checks) @classmethod - def interp_var_to_height( - cls, - var_array, - lev_array, - levels, - fixed_level_mask=None, - max_log_height=100, - max_workers=None, - ): + def interp_var_to_height(cls, + var_array, + lev_array, + levels, + fixed_level_mask=None, + max_log_height=100, + max_workers=None): """Interpolate data array to given level(s) based on h_array. Interpolation is done using windspeed log profile and is done for every 'z' column of [var, h] data. @@ -642,8 +640,7 @@ def interp_var_to_height( Array of interpolated values. """ lev_array, levels = Interpolator.prep_level_interp( - var_array, lev_array, levels - ) + var_array, lev_array, levels) array_shape = var_array.shape @@ -657,8 +654,7 @@ def interp_var_to_height( if max_workers == 1: for idt in range(array_shape[0]): h_t, v_t, mask = cls._get_timestep_interp_input( - lev_array, var_array, idt - ) + lev_array, var_array, idt) out, checks = cls.interp_single_ts( h_t, v_t, @@ -671,15 +667,13 @@ def interp_var_to_height( total_checks.append(checks) logger.info( - f'{idt + 1} of {array_shape[0]} timesteps finished.' 
- ) + f'{idt + 1} of {array_shape[0]} timesteps finished.') else: with ProcessPoolExecutor(max_workers=max_workers) as exe: for idt in range(array_shape[0]): h_t, v_t, mask = cls._get_timestep_interp_input( - lev_array, var_array, idt - ) + lev_array, var_array, idt) future = exe.submit( cls.interp_single_ts, h_t, @@ -691,8 +685,7 @@ def interp_var_to_height( ) futures[future] = idt logger.info( - f'{idt + 1} of {array_shape[0]} futures submitted.' - ) + f'{idt + 1} of {array_shape[0]} futures submitted.') for i, future in enumerate(as_completed(futures)): out, checks = future.result() out_array[:, futures[future], :] = out @@ -702,22 +695,16 @@ def interp_var_to_height( total_checks = np.concatenate(total_checks) good_count = total_checks.sum() total_count = len(total_checks) - logger.info( - 'Percent of points interpolated without issue: ' - f'{100 * good_count / total_count:.2f}' - ) + logger.info('Percent of points interpolated without issue: ' + f'{100 * good_count / total_count:.2f}') # Reshape out_array if isinstance(levels, (float, np.float32, int)): shape = (1, array_shape[-4], array_shape[-2], array_shape[-1]) out_array = out_array.T.reshape(shape) else: - shape = ( - len(levels), - array_shape[-4], - array_shape[-2], - array_shape[-1], - ) + shape = (len(levels), array_shape[-4], array_shape[-2], + array_shape[-1]) out_array = out_array.T.reshape(shape) return out_array diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index 24891b118..97bad45ea 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -1,33 +1,44 @@ """Code for regridding data from one list of coordinates to another""" -import numpy as np -from sklearn.neighbors import BallTree import logging -import psutil -from glob import glob -import pickle import os -import pandas as pd +import pickle +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime as dt -from concurrent.futures import as_completed, ThreadPoolExecutor +from glob import glob -from rex.utilities.fun_utils import get_fun_call_str +import numpy as np +import pandas as pd +import psutil from rex import MultiFileResource +from rex.utilities.fun_utils import get_fun_call_str +from sklearn.neighbors import BallTree from sup3r.postprocessing.file_handling import OutputMixIn, RexOutputs from sup3r.utilities import ModuleName -from sup3r.utilities.execution import DistributedProcess from sup3r.utilities.cli import BaseCLI +from sup3r.utilities.execution import DistributedProcess logger = logging.getLogger(__name__) -class TreeBuilder: - """TreeBuilder class for building ball tree and running all queries to - create full arrays of indices and distances for neighbor points +class Regridder: + """Basic Regridder class. Builds ball tree and runs all queries to + create full arrays of indices and distances for neighbor points. Computes + array of weights used to interpolate from old grid to new grid. """ - def __init__(self, source_meta, target_meta, cache_pattern=None, - leaf_size=4, k_neighbors=4, n_chunks=100, max_workers=None): + MIN_DISTANCE = 1e-12 + + def __init__( + self, + source_meta, + target_meta, + cache_pattern=None, + leaf_size=4, + k_neighbors=4, + n_chunks=100, + max_workers=None, + ): """Get weights and indices used to map from source grid to target grid Parameters @@ -53,6 +64,8 @@ def __init__(self, source_meta, target_meta, cache_pattern=None, to building full set of indices and distances for each target_meta coordinate. 
""" + logger.info('Initializing Regridder.') + self.cache_pattern = cache_pattern self.target_meta = target_meta self.source_meta = source_meta @@ -63,6 +76,7 @@ def __init__(self, source_meta, target_meta, cache_pattern=None, self.leaf_size = leaf_size self.distances = [None] * len(self.target_meta) self.indices = [None] * len(self.target_meta) + self._weights = None if self.cache_exists: self.load_cache() @@ -72,8 +86,16 @@ def __init__(self, source_meta, target_meta, cache_pattern=None, self.cache_all_queries() @classmethod - def run(cls, source_meta, target_meta, cache_pattern=None, - leaf_size=4, k_neighbors=4, n_chunks=100, max_workers=None): + def run( + cls, + source_meta, + target_meta, + cache_pattern=None, + leaf_size=4, + k_neighbors=4, + n_chunks=100, + max_workers=None, + ): """Query tree for every point in target_meta to get full set of indices and distances for the neighboring points in the source_meta. @@ -100,21 +122,44 @@ def run(cls, source_meta, target_meta, cache_pattern=None, to building full set of indices and distances for each target_meta coordinate. """ - tree_builder = cls(source_meta=source_meta, target_meta=target_meta, - cache_pattern=cache_pattern, leaf_size=leaf_size, - k_neighbors=k_neighbors, n_chunks=n_chunks, - max_workers=max_workers) - if not tree_builder.cache_exists: - tree_builder.get_all_queries(max_workers) - tree_builder.cache_all_queries() + regridder = cls( + source_meta=source_meta, + target_meta=target_meta, + cache_pattern=cache_pattern, + leaf_size=leaf_size, + k_neighbors=k_neighbors, + n_chunks=n_chunks, + max_workers=max_workers, + ) + if not regridder.cache_exists: + regridder.get_all_queries(max_workers) + regridder.cache_all_queries() + + @property + def weights(self): + """Get weights used for regridding""" + if self._weights is None: + dists = np.array(self.distances, dtype=np.float32) + mask = dists < self.MIN_DISTANCE + if mask.sum() > 0: + logger.info( + f'{np.sum(mask)} of {np.product(mask.shape)} ' + 'distances are zero.' + ) + dists[mask] = self.MIN_DISTANCE + weights = 1 / dists + self._weights = weights / np.sum(weights, axis=-1)[:, None] + return self._weights @property def cache_exists(self): """Check if cache exists before building tree.""" - cache_exists_check = (self.index_file is not None - and os.path.exists(self.index_file) - and self.distance_file is not None - and os.path.exists(self.distance_file)) + cache_exists_check = ( + self.index_file is not None + and os.path.exists(self.index_file) + and self.distance_file is not None + and os.path.exists(self.distance_file) + ) return cache_exists_check def build_tree(self): @@ -138,13 +183,13 @@ def get_all_queries(self, max_workers=None): self._parallel_queries(max_workers=max_workers) def _serial_queries(self): - """Get indices and distances for all points in target_meta, in serial - """ + """Get indices and distances for all points in target_meta, in + serial""" self.save_query(slice(None)) def _parallel_queries(self, max_workers=None): - """Get indices and distances for all points in target_meta, in serial - """ + """Get indices and distances for all points in target_meta, in + serial""" futures = {} now = dt.now() slices = np.arange(len(self.target_meta)) @@ -155,10 +200,13 @@ def _parallel_queries(self, max_workers=None): future = exe.submit(self.save_query, s_slice=s_slice) futures[future] = i mem = psutil.virtual_memory() - msg = ('Query futures submitted: {0} out of {1}. 
Current ' - 'memory usage is {2:.3f} GB out of {3:.3f} GB ' - 'total.'.format(i + 1, len(slices), mem.used / 1e9, - mem.total / 1e9)) + msg = ( + 'Query futures submitted: {} out of {}. Current ' + 'memory usage is {:.3f} GB out of {:.3f} GB ' + 'total.'.format( + i + 1, len(slices), mem.used / 1e9, mem.total / 1e9 + ) + ) logger.info(msg) logger.info(f'Submitted all query futures in {dt.now() - now}.') @@ -166,18 +214,21 @@ def _parallel_queries(self, max_workers=None): for i, future in enumerate(as_completed(futures)): idx = futures[future] mem = psutil.virtual_memory() - msg = ('Query futures completed: {0} out of ' - '{1}. Current memory usage is {2:.3f} ' - 'GB out of {3:.3f} GB total.'.format(i + 1, - len(futures), - mem.used / 1e9, - mem.total / 1e9)) + msg = ( + 'Query futures completed: {} out of ' + '{}. Current memory usage is {:.3f} ' + 'GB out of {:.3f} GB total.'.format( + i + 1, len(futures), mem.used / 1e9, mem.total / 1e9 + ) + ) logger.info(msg) try: future.result() except Exception as e: - msg = ('Failed to query coordinate chunk with ' - 'index={index}'.format(index=idx)) + msg = ( + 'Failed to query coordinate chunk with ' + 'index={index}'.format(index=idx) + ) logger.exception(msg) raise RuntimeError(msg) from e @@ -193,8 +244,9 @@ def load_cache(self): self.indices = pickle.load(f) with open(self.distance_file, 'rb') as f: self.distances = pickle.load(f) - logger.info(f'Loaded cache files: {self.index_file}, ' - f'{self.distance_file}') + logger.info( + f'Loaded cache files: {self.index_file}, ' f'{self.distance_file}' + ) def cache_all_queries(self): """Cache indices and distances from ball tree query""" @@ -203,8 +255,10 @@ def cache_all_queries(self): pickle.dump(self.indices, f, protocol=4) with open(self.distance_file, 'wb') as f: pickle.dump(self.distances, f, protocol=4) - logger.info(f'Saved cache files: {self.index_file}, ' - f'{self.distance_file}') + logger.info( + f'Saved cache files: {self.index_file}, ' + f'{self.distance_file}' + ) @property def index_file(self): @@ -259,27 +313,22 @@ def query_tree(self, s_slice): Array of indices for neighboring points for each point selected by s_slice. (n_ponts, k_neighbors) """ - return self.tree.query(self.get_spatial_chunk(s_slice), - k=self.k_neighbors) - + return self.tree.query( + self.get_spatial_chunk(s_slice), k=self.k_neighbors + ) -class Regridder(TreeBuilder): - """Regridder class for mapping list of coordinates to another. - Includes weights and indicies used to map from source grid to each point in - the new grid""" - - @staticmethod - def interpolate(distance_chunk, values): - """Interpolate to a new coordinate based on distances from that - coordinate and the values of the points at those distances + @classmethod + def interpolate(cls, distance_chunk, values): + """Interpolate to new coordinates based on distances from those + coordinates and the values of the points at those distances Parameters ---------- distance_chunk : ndarray Chunk of the full array of distances where distances[i] gives the - list of distances to the source coordinates to be used for - interpolation for the i-th coordinate in the target data. - (temporal, n_points, k_neighbors) + list of k_neighbors distances to the source coordinates to be used + for interpolation for the i-th coordinate in the target data. 
+ (n_points, k_neighbors) values : ndarray Array of values corresponding to the point distances with shape (temporal, n_points, k_neighbors) @@ -287,24 +336,56 @@ def interpolate(distance_chunk, values): Returns ------- ndarray - Time series of values at interpolated point with shape + Time series of values at interpolated points with shape (temporal, n_points) """ - dists = np.array(distance_chunk) - min_dist = 1e-12 - mask = (dists < min_dist) + dists = np.array(distance_chunk, dtype=np.float32) + mask = dists < cls.MIN_DISTANCE if mask.sum() > 0: - logger.info(f'{np.sum(mask)} of {np.product(mask.shape)} ' - 'distances are zero.') - dists[mask] = min_dist + logger.info( + f'{np.sum(mask)} of {np.product(mask.shape)} ' + 'distances are zero.' + ) + dists[mask] = cls.MIN_DISTANCE weights = 1 / dists norm = np.sum(weights, axis=-1) out = np.einsum('ijk,jk->ij', values, weights) / norm return out + def __call__(self, data): + """Regrid given spatiotemporal data over entire grid + + Parameters + ---------- + data : ndarray + Spatiotemporal data to regrid to target_meta + (spatial_1, spatial_2, temporal) + + Returns + ------- + out : ndarray + Flattened regridded spatiotemporal data + (spatial, temporal) + """ + vals = np.concatenate( + [ + data[:, :, i].flatten()[self.indices][np.newaxis] + for i in range(data.shape[-1]) + ], + axis=0, + ) + out = np.einsum('ijk,jk->ij', vals, self.weights) + return out.T + + +class WindRegridder(Regridder): + """Class to regrid windspeed and winddirection. Includes methods for + converting windspeed and winddirection to U and V and inverting after + interpolation""" + @classmethod def get_source_values(cls, index_chunk, feature, source_files): - """Get values to use for interpolation + """Get values to use for interpolation from h5 source files Parameters ---------- @@ -325,19 +406,16 @@ def get_source_values(cls, index_chunk, feature, source_files): (temporal, n_points, k_neighbors) """ with MultiFileResource(source_files) as res: - shape = (len(res.time_index), len(index_chunk), - len(index_chunk[0])) + shape = ( + len(res.time_index), + len(index_chunk), + len(index_chunk[0]), + ) tmp = np.array(index_chunk).flatten() out = res[feature, :, tmp] out = out.reshape(shape) return out - -class WindRegridder(Regridder): - """Class to regrid windspeed and winddirection. 
Includes methods for - converting windspeed and winddirection to U and V and inverting after - interpolation""" - @classmethod def get_source_uv(cls, index_chunk, height, source_files): """Get u/v wind components from windspeed and winddirection @@ -352,7 +430,7 @@ def get_source_uv(cls, index_chunk, height, source_files): height : int Wind height level source_files : list - List of paths to source files + List of paths to h5 source files Returns ------- @@ -363,10 +441,12 @@ def get_source_uv(cls, index_chunk, height, source_files): Array of meridional wind values to use for interpolation with shape (temporal, n_points, k_neighbors) """ - ws = cls.get_source_values(index_chunk, f'windspeed_{height}m', - source_files) - wd = cls.get_source_values(index_chunk, f'winddirection_{height}m', - source_files) + ws = cls.get_source_values( + index_chunk, f'windspeed_{height}m', source_files + ) + wd = cls.get_source_values( + index_chunk, f'winddirection_{height}m', source_files + ) u = ws * np.sin(np.radians(wd)) v = ws * np.cos(np.radians(wd)) @@ -400,8 +480,9 @@ def invert_uv(cls, u, v): return ws, wd @classmethod - def regrid_coordinates(cls, index_chunk, distance_chunk, height, - source_files): + def regrid_coordinates( + cls, index_chunk, distance_chunk, height, source_files + ): """Regrid wind fields at given height for the requested coordinate index @@ -420,7 +501,7 @@ def regrid_coordinates(cls, index_chunk, distance_chunk, height, height : int Wind height level source_files : list - List of paths to source files + List of paths to h5 source files Returns ------- @@ -444,10 +525,20 @@ class RegridOutput(OutputMixIn, DistributedProcess): a new target grid. The interpolated data is then written to new files, with one file for each field (e.g. windspeed_100m).""" - def __init__(self, source_files, out_pattern, target_meta, heights, - cache_pattern=None, leaf_size=4, k_neighbors=4, - incremental=False, n_chunks=100, max_nodes=1, - worker_kwargs=None): + def __init__( + self, + source_files, + out_pattern, + target_meta, + heights, + cache_pattern=None, + leaf_size=4, + k_neighbors=4, + incremental=False, + n_chunks=100, + max_nodes=1, + worker_kwargs=None, + ): """ Parameters ---------- @@ -484,13 +575,17 @@ def __init__(self, source_files, out_pattern, target_meta, heights, worker_kwargs = worker_kwargs or {} self.regrid_workers = worker_kwargs.get('regrid_workers', None) self.query_workers = worker_kwargs.get('query_workers', None) - self.source_files = (source_files if isinstance(source_files, list) - else glob(source_files)) + self.source_files = ( + source_files + if isinstance(source_files, list) + else glob(source_files) + ) self.target_meta_path = target_meta self.target_meta = pd.read_csv(self.target_meta_path) self.target_meta['gid'] = np.arange(len(self.target_meta)) self.target_meta = self.target_meta.sort_values( - ['latitude', 'longitude'], ascending=[False, True]) + ['latitude', 'longitude'], ascending=[False, True] + ) self.heights = heights self.incremental = incremental self.out_pattern = out_pattern @@ -501,25 +596,32 @@ def __init__(self, source_files, out_pattern, target_meta, heights, self.source_meta = res.meta self.global_attrs = res.global_attrs - self.regridder = WindRegridder(self.source_meta, - self.target_meta, - leaf_size=leaf_size, - k_neighbors=k_neighbors, - cache_pattern=cache_pattern, - n_chunks=n_chunks, - max_workers=self.query_workers) - DistributedProcess.__init__(self, max_nodes=max_nodes, - n_chunks=n_chunks, - max_chunks=len(self.regridder.indices), - 
incremental=incremental) - - logger.info('Initializing RegridOutput with ' - f'source_files={self.source_files}, ' - f'out_pattern={self.out_pattern}, ' - f'heights={self.heights}, ' - f'target_meta={target_meta}, ' - f'k_neighbors={k_neighbors}, and ' - f'n_chunks={n_chunks}.') + self.regridder = WindRegridder( + self.source_meta, + self.target_meta, + leaf_size=leaf_size, + k_neighbors=k_neighbors, + cache_pattern=cache_pattern, + n_chunks=n_chunks, + max_workers=self.query_workers, + ) + DistributedProcess.__init__( + self, + max_nodes=max_nodes, + n_chunks=n_chunks, + max_chunks=len(self.regridder.indices), + incremental=incremental, + ) + + logger.info( + 'Initializing RegridOutput with ' + f'source_files={self.source_files}, ' + f'out_pattern={self.out_pattern}, ' + f'heights={self.heights}, ' + f'target_meta={target_meta}, ' + f'k_neighbors={k_neighbors}, and ' + f'n_chunks={n_chunks}.' + ) logger.info(f'Max memory usage: {self.max_memory:.3f} GB.') @property @@ -561,8 +663,10 @@ def meta_chunks(self): @property def out_files(self): """Get list of output files for each spatial chunk""" - return [self.out_pattern.format(file_id=str(i).zfill(6)) - for i in range(self.chunks)] + return [ + self.out_pattern.format(file_id=str(i).zfill(6)) + for i in range(self.chunks) + ] @property def output_features(self): @@ -584,30 +688,33 @@ def get_node_cmd(cls, config): run regridding. """ - import_str = ('from sup3r.utilities.regridder import RegridOutput;\n' - 'from rex import init_logger;\n' - 'import time;\n' - 'from reV.pipeline.status import Status;\n') + import_str = ( + 'from sup3r.utilities.regridder import RegridOutput;\n' + 'from rex import init_logger;\n' + 'import time;\n' + 'from sup3r.pipeline import Status;\n' + ) regrid_fun_str = get_fun_call_str(cls, config) node_index = config['node_index'] log_file = config.get('log_file', None) log_level = config.get('log_level', 'INFO') - log_arg_str = (f'"sup3r", log_level="{log_level}"') + log_arg_str = f'"sup3r", log_level="{log_level}"' if log_file is not None: log_arg_str += f', log_file="{log_file}"' - cmd = (f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"regrid_output = {regrid_fun_str};\n" - f"regrid_output.run({node_index});\n" - "t_elap = time.time() - t0;\n" - ) + cmd = ( + f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"regrid_output = {regrid_fun_str};\n" + f"regrid_output.run({node_index});\n" + "t_elap = time.time() - t0;\n" + ) cmd = BaseCLI.add_status_cmd(config, ModuleName.REGRID, cmd) - cmd += (";\'\n") + cmd += ";\'\n" return cmd.replace('\\', '/') @@ -624,12 +731,15 @@ def run(self, node_index): return if self.regrid_workers == 1: - self._run_serial(source_files=self.source_files, - node_index=node_index) + self._run_serial( + source_files=self.source_files, node_index=node_index + ) else: - self._run_parallel(source_files=self.source_files, - node_index=node_index, - max_workers=self.regrid_workers) + self._run_parallel( + source_files=self.source_files, + node_index=node_index, + max_workers=self.regrid_workers, + ) def _run_serial(self, source_files, node_index): """Regrid data and write to output file, in serial. 
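The regridding math added above reduces to k-nearest-neighbour inverse-distance weighting: query the ball tree once per target coordinate, convert the distances into normalized 1/d weights (clipping anything below MIN_DISTANCE), and contract the weights against the neighbour values with einsum. Below is a standalone sketch of that pipeline on a toy grid; the random coordinates, the default BallTree metric, and the value array are illustrative and skip the class's caching and chunked queries.

import numpy as np
from sklearn.neighbors import BallTree

MIN_DISTANCE = 1e-12
k_neighbors = 4

source_coords = np.random.uniform(0, 1, size=(100, 2))  # toy (lat, lon) pairs
target_coords = np.random.uniform(0, 1, size=(10, 2))
source_vals = np.random.rand(5, 100)                    # (temporal, n_source_points)

tree = BallTree(source_coords, leaf_size=4)
dists, inds = tree.query(target_coords, k=k_neighbors)

dists = np.maximum(dists, MIN_DISTANCE)       # same guard as the weights property
weights = 1 / dists
weights /= weights.sum(axis=-1)[:, None]      # (n_target, k_neighbors)

vals = source_vals[:, inds]                   # (temporal, n_target, k_neighbors)
out = np.einsum('ijk,jk->ij', vals, weights)  # (temporal, n_target)
print(out.shape)                              # (5, 10)

In the class itself the query results are pickled by cache_all_queries and __call__ reuses the precomputed self.weights, so the tree only has to be queried once per target grid.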
@@ -644,15 +754,21 @@ def _run_serial(self, source_files, node_index): """ logger.info('Regridding all coordinates in serial.') for i, chunk_index in enumerate(self.node_chunks[node_index]): - self.write_coordinates(source_files=source_files, - chunk_index=chunk_index) + self.write_coordinates( + source_files=source_files, chunk_index=chunk_index + ) mem = psutil.virtual_memory() - msg = ('Coordinate chunks regridded: {0} out of {1}. ' - 'Current memory usage is {2:.3f} GB out of {3:.3f} ' - 'GB total.'.format(i + 1, - len(self.node_chunks[node_index]), - mem.used / 1e9, mem.total / 1e9)) + msg = ( + 'Coordinate chunks regridded: {} out of {}. ' + 'Current memory usage is {:.3f} GB out of {:.3f} ' + 'GB total.'.format( + i + 1, + len(self.node_chunks[node_index]), + mem.used / 1e9, + mem.total / 1e9, + ) + ) logger.info(msg) def _run_parallel(self, source_files, node_index, max_workers=None): @@ -673,13 +789,16 @@ def _run_parallel(self, source_files, node_index, max_workers=None): logger.info('Regridding all coordinates in parallel.') with ThreadPoolExecutor(max_workers=max_workers) as exe: for i, chunk_index in enumerate(self.node_chunks[node_index]): - future = exe.submit(self.write_coordinates, - source_files=source_files, - chunk_index=chunk_index) + future = exe.submit( + self.write_coordinates, + source_files=source_files, + chunk_index=chunk_index, + ) futures[future] = chunk_index mem = psutil.virtual_memory() - msg = ('Regrid futures submitted: {0} out of {1}'.format( - i + 1, len(self.node_chunks[node_index]))) + msg = 'Regrid futures submitted: {} out of {}'.format( + i + 1, len(self.node_chunks[node_index]) + ) logger.info(msg) logger.info(f'Submitted all regrid futures in {dt.now() - now}.') @@ -687,17 +806,26 @@ def _run_parallel(self, source_files, node_index, max_workers=None): for i, future in enumerate(as_completed(futures)): idx = futures[future] mem = psutil.virtual_memory() - msg = ('Regrid futures completed: {0} out of {1}, in {2}. ' - 'Current memory usage is {3:.3f} GB out of {4:.3f} GB ' - 'total.'.format(i + 1, len(futures), dt.now() - now, - mem.used / 1e9, mem.total / 1e9)) + msg = ( + 'Regrid futures completed: {} out of {}, in {}. 
' + 'Current memory usage is {:.3f} GB out of {:.3f} GB ' + 'total.'.format( + i + 1, + len(futures), + dt.now() - now, + mem.used / 1e9, + mem.total / 1e9, + ) + ) logger.info(msg) try: future.result() except Exception as e: - msg = ('Falied to regrid coordinate chunks with ' - 'index={index}'.format(index=idx)) + msg = ( + 'Falied to regrid coordinate chunks with ' + 'index={index}'.format(index=idx) + ) logger.exception(msg) raise RuntimeError(msg) from e @@ -726,15 +854,24 @@ def write_coordinates(self, source_files, chunk_index): fh.run_attrs = self.global_attrs for height in self.heights: ws, wd = self.regridder.regrid_coordinates( - index_chunk=index_chunk, distance_chunk=distance_chunk, - height=height, source_files=source_files) + index_chunk=index_chunk, + distance_chunk=distance_chunk, + height=height, + source_files=source_files, + ) features = [f'windspeed_{height}m', f'winddirection_{height}m'] for dset, data in zip(features, [ws, wd]): attrs, dtype = self.get_dset_attrs(dset) - fh.add_dataset(tmp_file, dset, data, dtype=dtype, - attrs=attrs, chunks=attrs['chunks']) + fh.add_dataset( + tmp_file, + dset, + data, + dtype=dtype, + attrs=attrs, + chunks=attrs['chunks'], + ) logger.info(f'Added {features} to {out_file}') os.replace(tmp_file, out_file) diff --git a/sup3r/utilities/topo.py b/sup3r/utilities/topo.py index 6728afbb3..ac01cecc1 100644 --- a/sup3r/utilities/topo.py +++ b/sup3r/utilities/topo.py @@ -1,17 +1,18 @@ """Sup3r topography utilities""" -import numpy as np import logging -from scipy.spatial import KDTree -from rex import Resource from abc import ABC, abstractmethod +import numpy as np +from rex import Resource +from scipy.spatial import KDTree + import sup3r.preprocessing.data_handling -from sup3r.preprocessing.data_handling import DataHandlerNC, DataHandlerH5 from sup3r.postprocessing.file_handling import OutputHandler +from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5 +from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC from sup3r.utilities.utilities import get_source_type - logger = logging.getLogger(__name__) @@ -22,9 +23,19 @@ class TopoExtract(ABC): (e.g. WTK or NSRDB) """ - def __init__(self, file_paths, topo_source, s_enhance, agg_factor, - target=None, shape=None, raster_file=None, max_delta=20, - input_handler=None, ti_workers=1): + def __init__( + self, + file_paths, + topo_source, + s_enhance, + agg_factor, + target=None, + shape=None, + raster_file=None, + max_delta=20, + input_handler=None, + ti_workers=1, + ): """ Parameters ---------- @@ -92,24 +103,33 @@ class will output a topography raster corresponding to the elif in_type == 'h5': input_handler = DataHandlerH5 else: - msg = ('Did not recognize input type "{}" for file paths: {}' - .format(in_type, file_paths)) + msg = 'Did not recognize input type "{}" for file paths: {}'.format( + in_type, file_paths + ) logger.error(msg) raise RuntimeError(msg) elif isinstance(input_handler, str): - input_handler = getattr(sup3r.preprocessing.data_handling, - input_handler, None) + input_handler = getattr( + sup3r.preprocessing.data_handling, input_handler, None + ) if input_handler is None: - msg = ('Could not find requested data handler class ' - f'"{input_handler}" in ' - 'sup3r.preprocessing.data_handling.') + msg = ( + 'Could not find requested data handler class ' + f'"{input_handler}" in ' + 'sup3r.preprocessing.data_handling.' 
+ ) logger.error(msg) raise KeyError(msg) self.input_handler = input_handler( - file_paths, [], target=target, shape=shape, - raster_file=raster_file, max_delta=max_delta, - worker_kwargs=dict(ti_workers=ti_workers)) + file_paths, + [], + target=target, + shape=shape, + raster_file=raster_file, + max_delta=max_delta, + worker_kwargs=dict(ti_workers=ti_workers), + ) @property @abstractmethod @@ -129,8 +149,10 @@ def lr_shape(self): @property def hr_shape(self): """Get the high-resolution spatial shape tuple""" - return (self._s_enhance * self.lr_lat_lon.shape[0], - self._s_enhance * self.lr_lat_lon.shape[1]) + return ( + self._s_enhance * self.lr_lat_lon.shape[0], + self._s_enhance * self.lr_lat_lon.shape[1], + ) @property def lr_lat_lon(self): @@ -156,8 +178,9 @@ def hr_lat_lon(self): """ if self._hr_lat_lon is None: if self._s_enhance > 1: - self._hr_lat_lon = OutputHandler.get_lat_lon(self.lr_lat_lon, - self.hr_shape) + self._hr_lat_lon = OutputHandler.get_lat_lon( + self.lr_lat_lon, self.hr_shape + ) else: self._hr_lat_lon = self.lr_lat_lon return self._hr_lat_lon @@ -171,9 +194,13 @@ def tree(self): @property def nn(self): - """Get the nearest neighbor indices """ - ll2 = np.vstack((self.hr_lat_lon[:, :, 0].flatten(), - self.hr_lat_lon[:, :, 1].flatten())).T + """Get the nearest neighbor indices""" + ll2 = np.vstack( + ( + self.hr_lat_lon[:, :, 0].flatten(), + self.hr_lat_lon[:, :, 1].flatten(), + ) + ).T _, nn = self.tree.query(ll2, k=self._agg_factor) if len(nn.shape) == 1: nn = np.expand_dims(nn, 1) @@ -192,14 +219,24 @@ def hr_elev(self): elev = elev.reshape(self.hr_shape) hr_elev.append(elev) hr_elev = np.dstack(hr_elev).mean(axis=-1) - logger.info('Finished mapping topo raster from {}' - .format(self._topo_source)) + logger.info( + 'Finished mapping topo raster from {}'.format(self._topo_source) + ) return hr_elev @classmethod - def get_topo_raster(cls, file_paths, topo_source, s_enhance, - agg_factor, target=None, shape=None, raster_file=None, - max_delta=20, input_handler=None): + def get_topo_raster( + cls, + file_paths, + topo_source, + s_enhance, + agg_factor, + target=None, + shape=None, + raster_file=None, + max_delta=20, + input_handler=None, + ): """Get the topography raster corresponding to the spatially enhanced grid from the file_paths input @@ -251,9 +288,17 @@ class will output a topography raster corresponding to the topo_source_h5, usually meters. 
""" - te = cls(file_paths, topo_source, s_enhance, agg_factor, - target=target, shape=shape, raster_file=raster_file, - max_delta=max_delta, input_handler=input_handler) + te = cls( + file_paths, + topo_source, + s_enhance, + agg_factor, + target=target, + shape=shape, + raster_file=raster_file, + max_delta=max_delta, + input_handler=input_handler, + ) return te.hr_elev @@ -290,11 +335,15 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - logger.info('Getting topography for full domain from ' - f'{self._topo_source}') + logger.info( + 'Getting topography for full domain from ' f'{self._topo_source}' + ) self.source_handler = DataHandlerNC( - self._topo_source, features=['topography'], - worker_kwargs=dict(ti_workers=self.ti_workers), val_split=0.0) + self._topo_source, + features=['topography'], + worker_kwargs=dict(ti_workers=self.ti_workers), + val_split=0.0, + ) @property def source_elevation(self): diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 01a1d925a..d6875aa4a 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -148,7 +148,7 @@ def get_chunk_slices(arr_size, chunk_size, index_slice=slice(None)): """ indices = np.arange(0, arr_size) - indices = indices[index_slice.start:index_slice.stop] + indices = indices[slice(index_slice.start, index_slice.stop)] step = 1 if index_slice.step is None else index_slice.step slices = [] start = indices[0] @@ -167,8 +167,10 @@ def get_raster_shape(raster_index): """Method to get shape of raster_index""" if any(isinstance(r, slice) for r in raster_index): - shape = (raster_index[0].stop - raster_index[0].start, - raster_index[1].stop - raster_index[1].start) + shape = ( + raster_index[0].stop - raster_index[0].start, + raster_index[1].stop - raster_index[1].start, + ) else: shape = raster_index.shape return shape @@ -192,11 +194,13 @@ def get_wrf_date_range(files): end date """ - date_start = re.search(r'(\d{4}(-|_)\d+(-|_)\d+(-|_)\d+(:|_)\d+(:|_)\d+)', - files[0]) + date_start = re.search( + r'(\d{4}(-|_)\d+(-|_)\d+(-|_)\d+(:|_)\d+(:|_)\d+)', files[0] + ) date_start = date_start if date_start is None else date_start[0] - date_end = re.search(r'(\d{4}(-|_)\d+(-|_)\d+(-|_)\d+(:|_)\d+(:|_)\d+)', - files[-1]) + date_end = re.search( + r'(\d{4}(-|_)\d+(-|_)\d+(-|_)\d+(:|_)\d+(:|_)\d+)', files[-1] + ) date_end = date_end if date_end is None else date_end[0] date_start = date_start.replace(':', '_') @@ -256,10 +260,8 @@ def weighted_box_sampler(data, shape, weights): slices : list List of spatial slices [spatial_1, spatial_2] """ - max_cols = (data.shape[1] if data.shape[1] < shape[1] - else shape[1]) - max_rows = (data.shape[0] if data.shape[0] < shape[0] - else shape[0]) + max_cols = data.shape[1] if data.shape[1] < shape[1] else shape[1] + max_rows = data.shape[0] if data.shape[0] < shape[0] else shape[0] max_cols = data.shape[1] - max_cols + 1 max_rows = data.shape[0] - max_rows + 1 indices = np.arange(0, max_rows * max_cols) @@ -268,8 +270,10 @@ def weighted_box_sampler(data, shape, weights): for i, w in enumerate(weights): weight_list += [w] * len(chunks[i]) weight_list /= np.sum(weight_list) - msg = ('Must have a sample_shape with a number of elements greater than ' - 'or equal to the number of spatial weights.') + msg = ( + 'Must have a sample_shape with a number of elements greater than ' + 'or equal to the number of spatial weights.' 
+ ) assert len(indices) >= len(weight_list), msg start = np.random.choice(indices, p=weight_list) row = start // max_cols @@ -307,8 +311,11 @@ def weighted_time_sampler(data, shape, weights): """ shape = data.shape[2] if data.shape[2] < shape else shape - t_indices = (np.arange(0, data.shape[2]) if shape == 1 - else np.arange(0, data.shape[2] - shape + 1)) + t_indices = ( + np.arange(0, data.shape[2]) + if shape == 1 + else np.arange(0, data.shape[2] - shape + 1) + ) t_chunks = np.array_split(t_indices, len(weights)) weight_list = [] @@ -365,18 +372,22 @@ def daily_time_sampler(data, shape, time_index): time slice with size shape of data starting at the beginning of the day """ - msg = (f'data {data.shape} and time index ({len(time_index)}) ' - 'shapes do not match, cannot sample daily data.') + msg = ( + f'data {data.shape} and time index ({len(time_index)}) ' + 'shapes do not match, cannot sample daily data.' + ) assert data.shape[2] == len(time_index), msg - ti_short = time_index[:-(shape - 1)] - midnight_ilocs = np.where((ti_short.hour == 0) - & (ti_short.minute == 0) - & (ti_short.second == 0))[0] + ti_short = time_index[: -(shape - 1)] + midnight_ilocs = np.where( + (ti_short.hour == 0) & (ti_short.minute == 0) & (ti_short.second == 0) + )[0] if not any(midnight_ilocs): - msg = ('Cannot sample time index of shape {} with requested daily ' - 'sample shape {}'.format(len(time_index), shape)) + msg = ( + 'Cannot sample time index of shape {} with requested daily ' + 'sample shape {}'.format(len(time_index), shape) + ) logger.error(msg) raise RuntimeError(msg) @@ -421,8 +432,10 @@ def nsrdb_sub_daily_sampler(data, shape, time_index, csr_ind=0): return tslice if night_mask.all(): - msg = (f'No daylight data found for tslice {tslice} ' - f'{time_index[tslice]}') + msg = ( + f'No daylight data found for tslice {tslice} ' + f'{time_index[tslice]}' + ) logger.warning(msg) warn(msg) return tslice @@ -466,7 +479,7 @@ def nsrdb_reduce_daily_data(data, shape, csr_ind=0): return data if night_mask.all(): - msg = (f'No daylight data found for data of shape {data.shape}') + msg = f'No daylight data found for data of shape {data.shape}' logger.warning(msg) warn(msg) return data @@ -596,7 +609,7 @@ def invert_uv(u, v, lat_lon): def temporal_coarsening(data, t_enhance=4, method='subsample'): - """"Coarsen data according to t_enhance resolution + """Coarsen data according to t_enhance resolution Parameters ---------- @@ -623,36 +636,69 @@ def temporal_coarsening(data, t_enhance=4, method='subsample'): elif method == 'average': coarse_data = np.nansum( data.reshape( - (data.shape[0], data.shape[1], - data.shape[2], -1, t_enhance, - data.shape[4])), axis=4) + ( + data.shape[0], + data.shape[1], + data.shape[2], + -1, + t_enhance, + data.shape[4], + ) + ), + axis=4, + ) coarse_data /= t_enhance elif method == 'max': coarse_data = np.max( data.reshape( - (data.shape[0], data.shape[1], - data.shape[2], -1, t_enhance, - data.shape[4])), axis=4) + ( + data.shape[0], + data.shape[1], + data.shape[2], + -1, + t_enhance, + data.shape[4], + ) + ), + axis=4, + ) elif method == 'min': coarse_data = np.min( data.reshape( - (data.shape[0], data.shape[1], - data.shape[2], -1, t_enhance, - data.shape[4])), axis=4) + ( + data.shape[0], + data.shape[1], + data.shape[2], + -1, + t_enhance, + data.shape[4], + ) + ), + axis=4, + ) elif method == 'total': coarse_data = np.nansum( data.reshape( - (data.shape[0], data.shape[1], - data.shape[2], -1, t_enhance, - data.shape[4])), axis=4) + ( + data.shape[0], + data.shape[1], + 
data.shape[2], + -1, + t_enhance, + data.shape[4], + ) + ), + axis=4, + ) else: - msg = ('Did not recognize temporal_coarsening method "{}", can ' - 'only accept one of: [subsample, average, total, max, min]' - .format(method)) + msg = ( + f'Did not recognize temporal_coarsening method "{method}", ' + 'can only accept one of: [subsample, average, total, max, min]' + ) logger.error(msg) raise KeyError(msg) @@ -663,7 +709,7 @@ def temporal_coarsening(data, t_enhance=4, method='subsample'): def temporal_simple_enhancing(data, t_enhance=4, mode='constant'): - """"Upsample data according to t_enhance resolution + """Upsample data according to t_enhance resolution Parameters ---------- @@ -672,6 +718,8 @@ def temporal_simple_enhancing(data, t_enhance=4, mode='constant'): (observations, spatial_1, spatial_2, temporal, features) t_enhance : int factor by which to enhance temporal dimension + mode : str + interpolation method for enhancement. Returns ------- @@ -684,22 +732,21 @@ def temporal_simple_enhancing(data, t_enhance=4, mode='constant'): elif t_enhance not in [None, 1] and len(data.shape) == 5: if mode == 'constant': enhancement = [1, 1, 1, t_enhance, 1] - enhanced_data = zoom(data, - enhancement, - order=0, - mode='nearest', - grid_mode=True) + enhanced_data = zoom( + data, enhancement, order=0, mode='nearest', grid_mode=True + ) elif mode == 'linear': index_t_hr = np.array(list(range(data.shape[3] * t_enhance))) index_t_lr = index_t_hr[::t_enhance] - enhanced_data = interp1d(index_t_lr, - data, - axis=3, - fill_value='extrapolate')(index_t_hr) + enhanced_data = interp1d( + index_t_lr, data, axis=3, fill_value='extrapolate' + )(index_t_hr) enhanced_data = np.array(enhanced_data, dtype=np.float32) elif len(data.shape) != 5: - msg = ('Data must be 5D to do temporal enhancing, but ' - f'received: {data.shape}') + msg = ( + 'Data must be 5D to do temporal enhancing, but ' + f'received: {data.shape}' + ) logger.error(msg) raise ValueError(msg) @@ -766,23 +813,27 @@ def smooth_data(low_res, training_features, smoothing_ignore, smoothing=None): """ if smoothing is not None: - feat_iter = [j for j in range(low_res.shape[-1]) - if training_features[j] not in smoothing_ignore] + feat_iter = [ + j + for j in range(low_res.shape[-1]) + if training_features[j] not in smoothing_ignore + ] for i in range(low_res.shape[0]): for j in feat_iter: if len(low_res.shape) == 5: for t in range(low_res.shape[-2]): low_res[i, ..., t, j] = gaussian_filter( - low_res[i, ..., t, j], smoothing, - mode='nearest') + low_res[i, ..., t, j], smoothing, mode='nearest' + ) else: low_res[i, ..., j] = gaussian_filter( - low_res[i, ..., j], smoothing, mode='nearest') + low_res[i, ..., j], smoothing, mode='nearest' + ) return low_res def spatial_coarsening(data, s_enhance=2, obs_axis=True): - """"Coarsen data according to s_enhance resolution + """Coarsen data according to s_enhance resolution Parameters ---------- @@ -806,54 +857,78 @@ def spatial_coarsening(data, s_enhance=2, obs_axis=True): """ if len(data.shape) < 3: - msg = ('Data must be 3D, 4D, or 5D to do spatial coarsening, but ' - f'received: {data.shape}') + msg = ( + 'Data must be 3D, 4D, or 5D to do spatial coarsening, but ' + f'received: {data.shape}' + ) logger.error(msg) raise ValueError(msg) if s_enhance is not None and s_enhance > 1: - bad1 = (obs_axis and (data.shape[1] % s_enhance != 0 - or data.shape[2] % s_enhance != 0)) - bad2 = (not obs_axis and (data.shape[0] % s_enhance != 0 - or data.shape[1] % s_enhance != 0)) + bad1 = obs_axis and ( + data.shape[1] % 
s_enhance != 0 or data.shape[2] % s_enhance != 0 + ) + bad2 = not obs_axis and ( + data.shape[0] % s_enhance != 0 or data.shape[1] % s_enhance != 0 + ) if bad1 or bad2: - msg = ('s_enhance must evenly divide grid size. ' - f'Received s_enhance: {s_enhance} with data shape: ' - f'{data.shape}') + msg = ( + 's_enhance must evenly divide grid size. ' + f'Received s_enhance: {s_enhance} with data shape: ' + f'{data.shape}' + ) logger.error(msg) raise ValueError(msg) if obs_axis and len(data.shape) == 5: - data = data.reshape(data.shape[0], - data.shape[1] // s_enhance, s_enhance, - data.shape[2] // s_enhance, s_enhance, - data.shape[3], - data.shape[4]) + data = data.reshape( + data.shape[0], + data.shape[1] // s_enhance, + s_enhance, + data.shape[2] // s_enhance, + s_enhance, + data.shape[3], + data.shape[4], + ) data = data.sum(axis=(2, 4)) / s_enhance**2 elif obs_axis and len(data.shape) == 4: - data = data.reshape(data.shape[0], - data.shape[1] // s_enhance, s_enhance, - data.shape[2] // s_enhance, s_enhance, - data.shape[3]) + data = data.reshape( + data.shape[0], + data.shape[1] // s_enhance, + s_enhance, + data.shape[2] // s_enhance, + s_enhance, + data.shape[3], + ) data = data.sum(axis=(2, 4)) / s_enhance**2 elif not obs_axis and len(data.shape) == 4: - data = data.reshape(data.shape[0] // s_enhance, s_enhance, - data.shape[1] // s_enhance, s_enhance, - data.shape[2], - data.shape[3]) + data = data.reshape( + data.shape[0] // s_enhance, + s_enhance, + data.shape[1] // s_enhance, + s_enhance, + data.shape[2], + data.shape[3], + ) data = data.sum(axis=(1, 3)) / s_enhance**2 elif not obs_axis and len(data.shape) == 3: - data = data.reshape(data.shape[0] // s_enhance, s_enhance, - data.shape[1] // s_enhance, s_enhance, - data.shape[2]) + data = data.reshape( + data.shape[0] // s_enhance, + s_enhance, + data.shape[1] // s_enhance, + s_enhance, + data.shape[2], + ) data = data.sum(axis=(1, 3)) / s_enhance**2 else: - msg = ('Data must be 3D, 4D, or 5D to do spatial coarsening, but ' - f'received: {data.shape}') + msg = ( + 'Data must be 3D, 4D, or 5D to do spatial coarsening, but ' + f'received: {data.shape}' + ) logger.error(msg) raise ValueError(msg) @@ -861,7 +936,7 @@ def spatial_coarsening(data, s_enhance=2, obs_axis=True): def spatial_simple_enhancing(data, s_enhance=2, obs_axis=True): - """"Simple enhancing according to s_enhance resolution + """Simple enhancing according to s_enhance resolution Parameters ---------- @@ -885,59 +960,53 @@ def spatial_simple_enhancing(data, s_enhance=2, obs_axis=True): """ if len(data.shape) < 3: - msg = ('Data must be 3D, 4D, or 5D to do spatial enhancing, but ' - f'received: {data.shape}') + msg = ( + 'Data must be 3D, 4D, or 5D to do spatial enhancing, but ' + f'received: {data.shape}' + ) logger.error(msg) raise ValueError(msg) if s_enhance is not None and s_enhance > 1: - if obs_axis and len(data.shape) == 5: enhancement = [1, s_enhance, s_enhance, 1, 1] - enhanced_data = zoom(data, - enhancement, - order=0, - mode='nearest', - grid_mode=True) + enhanced_data = zoom( + data, enhancement, order=0, mode='nearest', grid_mode=True + ) elif obs_axis and len(data.shape) == 4: enhancement = [1, s_enhance, s_enhance, 1] - enhanced_data = zoom(data, - enhancement, - order=0, - mode='nearest', - grid_mode=True) + enhanced_data = zoom( + data, enhancement, order=0, mode='nearest', grid_mode=True + ) elif not obs_axis and len(data.shape) == 4: enhancement = [s_enhance, s_enhance, 1, 1] - enhanced_data = zoom(data, - enhancement, - order=0, - mode='nearest', - 
grid_mode=True) + enhanced_data = zoom( + data, enhancement, order=0, mode='nearest', grid_mode=True + ) elif not obs_axis and len(data.shape) == 3: enhancement = [s_enhance, s_enhance, 1] - enhanced_data = zoom(data, - enhancement, - order=0, - mode='nearest', - grid_mode=True) + enhanced_data = zoom( + data, enhancement, order=0, mode='nearest', grid_mode=True + ) else: - msg = ('Data must be 3D, 4D, or 5D to do spatial enhancing, but ' - f'received: {data.shape}') + msg = ( + 'Data must be 3D, 4D, or 5D to do spatial enhancing, but ' + f'received: {data.shape}' + ) logger.error(msg) raise ValueError(msg) else: - enhanced_data = data return enhanced_data def lat_lon_coarsening(lat_lon, s_enhance=2): - """"Coarsen lat_lon according to s_enhance resolution + """Coarsen lat_lon according to s_enhance resolution Parameters ---------- @@ -952,10 +1021,10 @@ def lat_lon_coarsening(lat_lon, s_enhance=2): coarse_lat_lon : np.ndarray 2D array with same dimensions as lat_lon with new coarse resolution """ - coarse_lat_lon = lat_lon.reshape(-1, s_enhance, - lat_lon.shape[1] // s_enhance, - s_enhance, 2).sum((3, 1)) - coarse_lat_lon /= (s_enhance * s_enhance) + coarse_lat_lon = lat_lon.reshape( + -1, s_enhance, lat_lon.shape[1] // s_enhance, s_enhance, 2 + ).sum((3, 1)) + coarse_lat_lon /= s_enhance * s_enhance return coarse_lat_lon @@ -991,7 +1060,7 @@ def potential_temperature(T, P): ndarray Potential temperature """ - out = (T + np.float32(273.15)) + out = T + np.float32(273.15) out *= (np.float32(100000) / P) ** np.float32(0.286) return out @@ -1039,8 +1108,9 @@ def potential_temperature_difference(T_top, P_top, T_bottom, P_bottom): ndarray Difference in potential temperature between top and bottom levels """ - return (potential_temperature(T_top, P_top) - - potential_temperature(T_bottom, P_bottom)) + return potential_temperature(T_top, P_top) - potential_temperature( + T_bottom, P_bottom + ) def potential_temperature_average(T_top, P_top, T_bottom, P_bottom): @@ -1067,8 +1137,10 @@ def potential_temperature_average(T_top, P_top, T_bottom, P_bottom): Average of potential temperature between top and bottom levels """ - return ((potential_temperature(T_top, P_top) - + potential_temperature(T_bottom, P_bottom)) / np.float32(2.0)) + return ( + potential_temperature(T_top, P_top) + + potential_temperature(T_bottom, P_bottom) + ) / np.float32(2.0) def inverse_mo_length(U_star, flux_surf): @@ -1090,8 +1162,8 @@ def inverse_mo_length(U_star, flux_surf): Inverse Monin - Obukhov Length """ - denom = -U_star ** 3 * 300 - numer = (0.41 * 9.81 * flux_surf) + denom = -(U_star**3) * 300 + numer = 0.41 * 9.81 * flux_surf return numer / denom @@ -1123,16 +1195,15 @@ def bvf_squared(T_top, T_bottom, P_top, P_bottom, delta_h): """ bvf2 = np.float32(9.81 / delta_h) - bvf2 *= potential_temperature_difference( - T_top, P_top, T_bottom, P_bottom) - bvf2 /= potential_temperature_average( - T_top, P_top, T_bottom, P_bottom) + bvf2 *= potential_temperature_difference(T_top, P_top, T_bottom, P_bottom) + bvf2 /= potential_temperature_average(T_top, P_top, T_bottom, P_bottom) return bvf2 -def gradient_richardson_number(T_top, T_bottom, P_top, P_bottom, U_top, - U_bottom, V_top, V_bottom, delta_h): +def gradient_richardson_number( + T_top, T_bottom, P_top, P_bottom, U_top, U_bottom, V_top, V_bottom, delta_h +): """Formula for the gradient richardson number - related to the bouyant production or consumption of turbulence divided by the shear production of turbulence. 
Used to indicate dynamic stability @@ -1170,10 +1241,9 @@ def gradient_richardson_number(T_top, T_bottom, P_top, P_bottom, U_top, ws_grad = (U_top - U_bottom) ** 2 ws_grad += (V_top - V_bottom) ** 2 - ws_grad /= delta_h ** 2 + ws_grad /= delta_h**2 ws_grad[ws_grad < 1e-6] = 1e-6 - Ri = bvf_squared( - T_top, T_bottom, P_top, P_bottom, delta_h) / ws_grad + Ri = bvf_squared(T_top, T_bottom, P_top, P_bottom, delta_h) / ws_grad del ws_grad return Ri @@ -1193,8 +1263,9 @@ def nn_fill_array(array): """ nan_mask = np.isnan(array) - indices = nd.distance_transform_edt(nan_mask, return_distances=False, - return_indices=True) + indices = nd.distance_transform_edt( + nan_mask, return_distances=False, return_indices=True + ) array = array[tuple(indices)] return array @@ -1215,9 +1286,10 @@ def ignore_case_path_fetch(fp): dirname = os.path.dirname(fp) basename = os.path.basename(fp) - for file in os.listdir(dirname): - if fnmatch(file.lower(), basename.lower()): - return os.path.join(dirname, file) + if os.path.exists(dirname): + for file in os.listdir(dirname): + if fnmatch(file.lower(), basename.lower()): + return os.path.join(dirname, file) return None @@ -1270,11 +1342,13 @@ def rotor_equiv_ws(data, heights): rotor_center = np.mean(heights) rel_heights = [h - rotor_center for h in heights] - areas = [rotor_area(rel_heights[i], rel_heights[i + 1]) - for i in range(len(rel_heights) - 1)] + areas = [ + rotor_area(rel_heights[i], rel_heights[i + 1]) + for i in range(len(rel_heights) - 1) + ] total_area = np.sum(areas) areas /= total_area - rews = np.zeros(data[list(data.keys())[0]].shape) + rews = np.zeros(data[next(iter(data.keys()))].shape) for i in range(len(heights) - 1): ws_0 = data[f'windspeed_{heights[i]}m'] ws_1 = data[f'windspeed_{heights[i + 1]}m'] @@ -1282,7 +1356,7 @@ def rotor_equiv_ws(data, heights): wd_1 = data[f'winddirection_{heights[i + 1]}m'] ws_cos_0 = np.cos(np.radians(wd_0)) * ws_0 ws_cos_1 = np.cos(np.radians(wd_1)) * ws_1 - rews += areas[i] * (ws_cos_0 + ws_cos_1)**3 + rews += areas[i] * (ws_cos_0 + ws_cos_1) ** 3 rews = 0.5 * np.cbrt(rews) return rews @@ -1350,19 +1424,25 @@ def get_input_handler_class(file_paths, input_handler_name): elif input_type == 'h5': input_handler_name = 'DataHandlerH5' - logger.info('"input_handler" arg was not provided. Using ' - f'"{input_handler_name}". If this is ' - 'incorrect, please provide ' - 'input_handler="DataHandlerName".') + logger.info( + '"input_handler" arg was not provided. Using ' + f'"{input_handler_name}". If this is ' + 'incorrect, please provide ' + 'input_handler="DataHandlerName".' + ) if isinstance(input_handler_name, str): import sup3r.preprocessing.data_handling - HandlerClass = getattr(sup3r.preprocessing.data_handling, - input_handler_name, None) + + HandlerClass = getattr( + sup3r.preprocessing.data_handling, input_handler_name, None + ) if HandlerClass is None: - msg = ('Could not find requested data handler class ' - f'"{input_handler_name}" in sup3r.preprocessing.data_handling.') + msg = ( + 'Could not find requested data handler class ' + f'"{input_handler_name}" in sup3r.preprocessing.data_handling.' 
+ ) logger.error(msg) raise KeyError(msg) @@ -1448,8 +1528,9 @@ def st_interp(low, s_enhance, t_enhance, t_centered=False): new_t += 5 / hr_t # set RegularGridInterpolator to do extrapolation - interp = RegularGridInterpolator((y, x, t), low, bounds_error=False, - fill_value=None) + interp = RegularGridInterpolator( + (y, x, t), low, bounds_error=False, fill_value=None + ) # perform interp X, Y, T = np.meshgrid(new_x, new_y, new_t) diff --git a/tests/data/test_era5_co_2012.nc b/tests/data/test_era5_co_2012.nc new file mode 100644 index 000000000..82576f2c2 Binary files /dev/null and b/tests/data/test_era5_co_2012.nc differ diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py new file mode 100644 index 000000000..02669afcc --- /dev/null +++ b/tests/data_handling/test_dual_data_handling.py @@ -0,0 +1,368 @@ +# -*- coding: utf-8 -*- +"""Test the basic training of super resolution GAN""" +import os +import tempfile + +import matplotlib.pyplot as plt +import numpy as np +from rex import init_logger + +from sup3r import TEST_DATA_DIR +from sup3r.preprocessing.data_handling.dual_data_handling import ( + DualDataHandler, +) +from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5 +from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC +from sup3r.preprocessing.dual_batch_handling import ( + DualBatchHandler, + SpatialDualBatchHandler, +) +from sup3r.utilities.utilities import spatial_coarsening + +FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') +FP_ERA = os.path.join(TEST_DATA_DIR, 'test_era5_co_2012.nc') +TARGET_COORD = (39.01, -105.15) +FEATURES = ['U_100m', 'V_100m'] + + +def test_dual_data_handler( + log=True, full_shape=(20, 20), sample_shape=(10, 10, 1), plot=True +): + """Test basic spatial model training with only gen content loss.""" + if log: + init_logger('sup3r', log_level='DEBUG') + + # need to reduce the number of temporal examples to test faster + hr_handler = DataHandlerH5( + FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + lr_handler = DataHandlerNC( + FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // 2, sample_shape[1] // 2, 1), + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + + dual_handler = DualDataHandler( + hr_handler, lr_handler, s_enhance=2, t_enhance=1, val_split=0.1 + ) + + batch_handler = SpatialDualBatchHandler( + [dual_handler], batch_size=2, s_enhance=2, n_batches=10 + ) + + if plot: + for i, batch in enumerate(batch_handler): + fig, ax = plt.subplots(1, 2, figsize=(5, 10)) + fig.suptitle(f'High vs Low Res ({dual_handler.features[-1]})') + ax[0].set_title('High Res') + ax[0].imshow(np.mean(batch.high_res[..., -1], axis=0)) + ax[1].set_title('Low Res') + ax[1].imshow(np.mean(batch.low_res[..., -1], axis=0)) + fig.savefig( + f'./high_vs_low_{str(i).zfill(3)}.png', bbox_inches='tight' + ) + + +def test_regrid_caching( + log=True, full_shape=(20, 20), sample_shape=(10, 10, 1) +): + """Test caching and loading of regridded data""" + if log: + init_logger('sup3r', log_level='DEBUG') + + # need to reduce the number of temporal examples to test faster + with tempfile.TemporaryDirectory() as td: + hr_handler = DataHandlerH5( + FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + 
lr_handler = DataHandlerNC( + FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // 2, sample_shape[1] // 2, 1), + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + old_dh = DualDataHandler( + hr_handler, + lr_handler, + s_enhance=2, + t_enhance=1, + val_split=0.1, + regrid_cache_pattern=f'{td}/cache.pkl', + ) + + # Load handlers again + hr_handler = DataHandlerH5( + FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + lr_handler = DataHandlerNC( + FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // 2, sample_shape[1] // 2, 1), + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + new_dh = DualDataHandler( + hr_handler, + lr_handler, + s_enhance=2, + t_enhance=1, + val_split=0.1, + regrid_cache_pattern=f'{td}/cache.pkl', + ) + assert np.array_equal(old_dh.lr_data, new_dh.lr_data) + assert np.array_equal(old_dh.hr_data, new_dh.hr_data) + + +def test_st_dual_batch_handler( + log=False, full_shape=(20, 20), sample_shape=(10, 10, 4) +): + """Test spatiotemporal dual batch handler.""" + t_enhance = 2 + s_enhance = 2 + + if log: + init_logger('sup3r', log_level='DEBUG') + + # need to reduce the number of temporal examples to test faster + hr_handler = DataHandlerH5( + FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + lr_handler = DataHandlerNC( + FP_ERA, + FEATURES, + sample_shape=( + sample_shape[0] // s_enhance, + sample_shape[1] // s_enhance, + sample_shape[2] // t_enhance, + ), + temporal_slice=slice(None, None, t_enhance * 10), + worker_kwargs=dict(max_workers=1), + ) + + dual_handler = DualDataHandler( + hr_handler, + lr_handler, + s_enhance=s_enhance, + t_enhance=t_enhance, + val_split=0.1, + ) + + batch_handler = DualBatchHandler( + [dual_handler], + batch_size=2, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=10, + ) + + for batch in batch_handler: + for i, index in enumerate(batch_handler.current_batch_indices): + hr_index = index['hr_index'] + lr_index = index['lr_index'] + + coarse_lat_lon = spatial_coarsening( + dual_handler.hr_lat_lon[hr_index[:2]], obs_axis=False + ) + lr_lat_lon = dual_handler.lr_lat_lon[lr_index[:2]] + assert np.array_equal(coarse_lat_lon, lr_lat_lon) + + coarse_ti = dual_handler.hr_time_index[hr_index[2]][::t_enhance] + lr_ti = dual_handler.lr_time_index[lr_index[2]] + assert np.array_equal(coarse_ti.values, lr_ti.values) + + assert np.array_equal( + batch.high_res[i], + dual_handler.hr_data[hr_index], + ) + assert np.array_equal( + batch.low_res[i], + dual_handler.lr_data[lr_index], + ) + + +def test_spatial_dual_batch_handler( + log=False, full_shape=(20, 20), sample_shape=(10, 10, 1), plot=True +): + """Test spatial dual batch handler.""" + if log: + init_logger('sup3r', log_level='DEBUG') + + # need to reduce the number of temporal examples to test faster + hr_handler = DataHandlerH5( + FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + lr_handler = DataHandlerNC( + FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // 2, sample_shape[1] // 2, 1), + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + + dual_handler = DualDataHandler( + hr_handler, lr_handler, s_enhance=2, t_enhance=1, 
val_split=0.0 + ) + + batch_handler = SpatialDualBatchHandler( + [dual_handler], batch_size=2, s_enhance=2, t_enhance=1, n_batches=10 + ) + + for i, batch in enumerate(batch_handler): + for j, index in enumerate(batch_handler.current_batch_indices): + hr_index = index['hr_index'] + lr_index = index['lr_index'] + + assert np.array_equal( + batch.high_res[j, :, :], + dual_handler.hr_data[hr_index][..., 0, :], + ) + assert np.array_equal( + batch.low_res[j, :, :], + dual_handler.lr_data[lr_index][..., 0, :], + ) + + coarse_lat_lon = spatial_coarsening( + dual_handler.hr_lat_lon[hr_index[:2]], obs_axis=False + ) + lr_lat_lon = dual_handler.lr_lat_lon[lr_index[:2]] + assert np.array_equal(coarse_lat_lon, lr_lat_lon) + + if plot: + for ifeature in range(batch.high_res.shape[-1]): + data_fine = batch.high_res[0, :, :, ifeature] + data_coarse = batch.low_res[0, :, :, ifeature] + fig = plt.figure(figsize=(10, 5)) + ax1 = fig.add_subplot(121) + ax2 = fig.add_subplot(122) + ax1.imshow(data_fine) + ax2.imshow(data_coarse) + plt.savefig(f'./{i}_{ifeature}.png', bbox_inches='tight') + plt.close() + + +def test_validation_batching( + log=False, full_shape=(20, 20), sample_shape=(10, 10, 4) +): + """Test batching of validation data for dual batch handler""" + if log: + init_logger('sup3r', log_level='DEBUG') + + s_enhance = 2 + t_enhance = 2 + + hr_handler = DataHandlerH5( + FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + lr_handler = DataHandlerNC( + FP_ERA, + FEATURES, + sample_shape=( + sample_shape[0] // s_enhance, + sample_shape[1] // s_enhance, + sample_shape[2] // t_enhance, + ), + temporal_slice=slice(None, None, t_enhance * 10), + worker_kwargs=dict(max_workers=1), + ) + + dual_handler = DualDataHandler( + hr_handler, + lr_handler, + s_enhance=s_enhance, + t_enhance=t_enhance, + val_split=0.1, + ) + + batch_handler = DualBatchHandler( + [dual_handler], + batch_size=2, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=10, + ) + + for batch in batch_handler.val_data: + assert batch.high_res.dtype == np.dtype(np.float32) + assert batch.low_res.dtype == np.dtype(np.float32) + assert batch.low_res.shape[0] == batch.high_res.shape[0] + assert batch.low_res.shape == ( + batch.low_res.shape[0], + sample_shape[0] // s_enhance, + sample_shape[1] // s_enhance, + sample_shape[2] // t_enhance, + len(FEATURES), + ) + assert batch.high_res.shape == ( + batch.high_res.shape[0], + sample_shape[0], + sample_shape[1], + sample_shape[2], + len(FEATURES), + ) + + for j, index in enumerate( + batch_handler.val_data.current_batch_indices + ): + hr_index = index['hr_index'] + lr_index = index['lr_index'] + + assert np.array_equal( + batch.high_res[j], + dual_handler.hr_val_data[hr_index], + ) + assert np.array_equal( + batch.low_res[j], + dual_handler.lr_val_data[lr_index], + ) + + coarse_lat_lon = spatial_coarsening( + dual_handler.hr_lat_lon[hr_index[:2]], obs_axis=False + ) + lr_lat_lon = dual_handler.lr_lat_lon[lr_index[:2]] + + assert np.array_equal(coarse_lat_lon, lr_lat_lon) + + coarse_ti = dual_handler.hr_val_time_index[hr_index[2]][ + ::t_enhance + ] + lr_ti = dual_handler.lr_val_time_index[lr_index[2]] + assert np.array_equal(coarse_ti.values, lr_ti.values) diff --git a/tests/data_handling/test_exo_data_handling.py b/tests/data_handling/test_exo_data_handling.py index 68b01b602..0a147fdd6 100644 --- a/tests/data_handling/test_exo_data_handling.py +++ 
b/tests/data_handling/test_exo_data_handling.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """pytests for exogenous data handling""" -import shutil import os +import shutil + import numpy as np from sup3r import TEST_DATA_DIR -from sup3r.preprocessing.exogenous_data_handling import ExogenousDataHandler - +from sup3r.preprocessing.data_handling import ExogenousDataHandler FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') diff --git a/tests/training/test_train_gan_lr_era.py b/tests/training/test_train_gan_lr_era.py new file mode 100644 index 000000000..59ef88105 --- /dev/null +++ b/tests/training/test_train_gan_lr_era.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +"""Test the basic training of super resolution GAN with dual data handler""" +import json +import os +import tempfile + +import numpy as np +import pytest +import tensorflow as tf +from rex import init_logger +from tensorflow.python.framework.errors_impl import InvalidArgumentError + +from sup3r import CONFIG_DIR, TEST_DATA_DIR +from sup3r.models import Sup3rGan +from sup3r.preprocessing.data_handling import ( + DataHandlerH5, + DataHandlerNC, + DualDataHandler, +) +from sup3r.preprocessing.dual_batch_handling import ( + DualBatchHandler, + SpatialDualBatchHandler, +) + +FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') +FP_ERA = os.path.join(TEST_DATA_DIR, 'test_era5_co_2012.nc') +TARGET_COORD = (39.01, -105.15) +FEATURES = ['U_100m', 'V_100m'] + + +def test_train_spatial( + log=False, full_shape=(20, 20), sample_shape=(10, 10, 1), n_epoch=2 +): + """Test basic spatial model training with only gen content loss.""" + if log: + init_logger('sup3r', log_level='DEBUG') + + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + + Sup3rGan.seed() + model = Sup3rGan( + fp_gen, fp_disc, learning_rate=2e-5, loss='MeanAbsoluteError' + ) + + # need to reduce the number of temporal examples to test faster + hr_handler = DataHandlerH5( + FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=full_shape, + sample_shape=sample_shape, + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + lr_handler = DataHandlerNC( + FP_ERA, + FEATURES, + sample_shape=(sample_shape[0] // 2, sample_shape[1] // 2, 1), + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + + dual_handler = DualDataHandler( + hr_handler, lr_handler, s_enhance=2, t_enhance=1, val_split=0.1 + ) + + batch_handler = SpatialDualBatchHandler( + [dual_handler], batch_size=2, s_enhance=2, n_batches=2 + ) + + with tempfile.TemporaryDirectory() as td: + # test that training works and reduces loss + model.train( + batch_handler, + n_epoch=n_epoch, + weight_gen_advers=0.0, + train_gen=True, + train_disc=False, + checkpoint_int=1, + out_dir=os.path.join(td, 'test_{epoch}'), + ) + + assert len(model.history) == n_epoch + vlossg = model.history['val_loss_gen'].values + tlossg = model.history['train_loss_gen'].values + assert np.sum(np.diff(vlossg)) < 0 + assert np.sum(np.diff(tlossg)) < 0 + assert 'test_0' in os.listdir(td) + assert 'test_1' in os.listdir(td) + assert 'model_gen.pkl' in os.listdir(td + '/test_1') + assert 'model_disc.pkl' in os.listdir(td + '/test_1') + + # make an un-trained dummy model + dummy = Sup3rGan( + fp_gen, fp_disc, learning_rate=2e-5, loss='MeanAbsoluteError' + ) + + # test save/load functionality + out_dir = os.path.join(td, 'spatial_gan') + model.save(out_dir) + loaded = model.load(out_dir) + + assert isinstance(dummy.loss_fun, 
tf.keras.losses.MeanAbsoluteError) + assert isinstance(model.loss_fun, tf.keras.losses.MeanAbsoluteError) + assert isinstance(loaded.loss_fun, tf.keras.losses.MeanAbsoluteError) + + for batch in batch_handler: + out_og = model._tf_generate(batch.low_res) + out_dummy = dummy._tf_generate(batch.low_res) + out_loaded = loaded._tf_generate(batch.low_res) + + # make sure the loaded model generates the same data as the saved + # model but different than the dummy + tf.assert_equal(out_og, out_loaded) + with pytest.raises(InvalidArgumentError): + tf.assert_equal(out_og, out_dummy) + + # make sure the trained model has less loss than dummy + loss_og = model.calc_loss(batch.high_res, out_og)[0] + loss_dummy = dummy.calc_loss(batch.high_res, out_dummy)[0] + assert loss_og.numpy() < loss_dummy.numpy() + + +def test_train_st(n_epoch=3, log=True): + """Test basic spatiotemporal model training with only gen content loss.""" + if log: + init_logger('sup3r', log_level='DEBUG') + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + + Sup3rGan.seed() + model = Sup3rGan( + fp_gen, fp_disc, learning_rate=1e-5, loss='MeanAbsoluteError' + ) + + hr_handler = DataHandlerH5( + FP_WTK, + FEATURES, + target=TARGET_COORD, + shape=(20, 20), + sample_shape=(12, 12, 16), + temporal_slice=slice(None, None, 10), + worker_kwargs=dict(max_workers=1), + ) + lr_handler = DataHandlerNC( + FP_ERA, + FEATURES, + sample_shape=(4, 4, 4), + temporal_slice=slice(None, None, 40), + worker_kwargs=dict(max_workers=1), + ) + + dual_handler = DualDataHandler( + hr_handler, lr_handler, s_enhance=3, t_enhance=4, val_split=0.1 + ) + + batch_handler = DualBatchHandler( + [dual_handler], + batch_size=5, + s_enhance=3, + t_enhance=4, + n_batches=5, + worker_kwargs=dict(max_workers=1), + ) + + assert batch_handler.norm_workers == 1 + assert batch_handler.stats_workers == 1 + assert batch_handler.load_workers == 1 + + with tempfile.TemporaryDirectory() as td: + # test that training works and reduces loss + model.train( + batch_handler, + n_epoch=n_epoch, + weight_gen_advers=0.0, + train_gen=True, + train_disc=False, + checkpoint_int=1, + out_dir=os.path.join(td, 'test_{epoch}'), + ) + + assert 'config_generator' in model.meta + assert 'config_discriminator' in model.meta + assert len(model.history) == n_epoch + assert all(model.history['train_gen_trained_frac'] == 1) + assert all(model.history['train_disc_trained_frac'] == 0) + vlossg = model.history['val_loss_gen'].values + tlossg = model.history['train_loss_gen'].values + assert np.sum(np.diff(vlossg)) < 0 + assert np.sum(np.diff(tlossg)) < 0 + assert 'test_0' in os.listdir(td) + assert 'test_1' in os.listdir(td) + assert 'model_gen.pkl' in os.listdir(td + '/test_1') + assert 'model_disc.pkl' in os.listdir(td + '/test_1') + + # test save/load functionality + out_dir = os.path.join(td, 'st_gan') + model.save(out_dir) + loaded = model.load(out_dir) + + with open(os.path.join(out_dir, 'model_params.json')) as f: + model_params = json.load(f) + + assert np.allclose(model_params['optimizer']['learning_rate'], 1e-5) + assert np.allclose( + model_params['optimizer_disc']['learning_rate'], 1e-5 + ) + assert 'learning_rate_gen' in model.history + assert 'learning_rate_disc' in model.history + + assert 'config_generator' in loaded.meta + assert 'config_discriminator' in loaded.meta + assert model.meta['class'] == 'Sup3rGan' + + # make an un-trained dummy model + dummy = Sup3rGan( + fp_gen, fp_disc, learning_rate=1e-5, 
loss='MeanAbsoluteError' + ) + + for batch in batch_handler: + out_og = model._tf_generate(batch.low_res) + out_dummy = dummy._tf_generate(batch.low_res) + out_loaded = loaded._tf_generate(batch.low_res) + + # make sure the loaded model generates the same data as the saved + # model but different than the dummy + tf.assert_equal(out_og, out_loaded) + with pytest.raises(InvalidArgumentError): + tf.assert_equal(out_og, out_dummy) + + # make sure the trained model has less loss than dummy + loss_og = model.calc_loss(batch.high_res, out_og)[0] + loss_dummy = dummy.calc_loss(batch.high_res, out_dummy)[0] + assert loss_og.numpy() < loss_dummy.numpy() + + # test that a new shape can be passed through the generator + test_data = np.ones((3, 10, 10, 4, len(FEATURES)), dtype=np.float32) + y_test = model._tf_generate(test_data) + assert y_test.shape[0] == test_data.shape[0] + assert y_test.shape[1] == test_data.shape[1] * 3 + assert y_test.shape[2] == test_data.shape[2] * 3 + assert y_test.shape[3] == test_data.shape[3] * 4 + assert y_test.shape[4] == test_data.shape[4]
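For readers skimming the new dual-handler tests above: they repeatedly call spatial_coarsening(..., obs_axis=False) to confirm that the coarsened high-res lat/lon grid matches the low-res handler's grid. The standalone numpy sketch below (illustration only, not part of the patch; the array values are made up) reproduces the block-mean reshape that the obs_axis=False, 3D branch of spatial_coarsening performs, so the intent of those asserts is easier to follow.

    import numpy as np

    # Block-average a (spatial_1, spatial_2, features) array the same way
    # spatial_coarsening does for the obs_axis=False, 3D case. s_enhance
    # must evenly divide both spatial dimensions, as the utility enforces.
    s_enhance = 2
    data = np.arange(4 * 4 * 1, dtype=np.float32).reshape(4, 4, 1)

    coarse = data.reshape(
        data.shape[0] // s_enhance, s_enhance,
        data.shape[1] // s_enhance, s_enhance,
        data.shape[2],
    ).sum(axis=(1, 3)) / s_enhance**2

    # Each coarse cell is the mean of an s_enhance x s_enhance block of
    # fine cells, so a (4, 4, 1) input becomes (2, 2, 1) here.
    assert coarse.shape == (2, 2, 1)
    assert np.isclose(coarse[0, 0, 0], data[:2, :2, 0].mean())

Applied to a (rows, cols, 2) lat/lon array, the same operation yields the block-averaged coordinates that the tests compare against dual_handler.lr_lat_lon.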