diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 6cc1167ca..fa4c0567b 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -960,7 +960,7 @@ def _run_skill_eval(cls, bias_data, base_data, bias_feature, base_dset): out = {} bias_mean = np.nanmean(bias_data) base_mean = np.nanmean(base_data) - out[f'{bias_feature}_bias'] = (bias_mean - base_mean) + out[f'{bias_feature}_bias'] = bias_mean - base_mean out[f'bias_{bias_feature}_mean'] = bias_mean out[f'bias_{bias_feature}_std'] = np.nanstd(bias_data) diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 1fad7d4eb..863e1c920 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -243,8 +243,7 @@ def t_lr_crop_slices(self): """ if self._t_lr_crop_slices is None: self._t_lr_crop_slices = self.get_cropped_slices( - self.t_lr_slices, self.t_lr_pad_slices, 1 - ) + self.t_lr_slices, self.t_lr_pad_slices, 1) return self._t_lr_crop_slices @@ -321,12 +320,10 @@ def s_lr_crop_slices(self): """ if self._s_lr_crop_slices is None: self._s_lr_crop_slices = [] - s1_crop_slices = self.get_cropped_slices( - self.s1_lr_slices, self.s1_lr_pad_slices, 1 - ) - s2_crop_slices = self.get_cropped_slices( - self.s2_lr_slices, self.s2_lr_pad_slices, 1 - ) + s1_crop_slices = self.get_cropped_slices(self.s1_lr_slices, + self.s1_lr_pad_slices, 1) + s2_crop_slices = self.get_cropped_slices(self.s2_lr_slices, + self.s2_lr_pad_slices, 1) for i, _ in enumerate(self.s1_lr_slices): for j, _ in enumerate(self.s2_lr_slices): lr_crop_slice = ( @@ -426,9 +423,9 @@ def s1_lr_slices(self): """List of low resolution spatial slices for first spatial dimension considering padding on all sides of the spatial raster.""" ind = slice(0, self.grid_shape[0]) - slices = get_chunk_slices( - self.grid_shape[0], self.chunk_shape[0], index_slice=ind - ) + slices = get_chunk_slices(self.grid_shape[0], + self.chunk_shape[0], + index_slice=ind) return slices @property @@ -436,9 +433,9 @@ def s2_lr_slices(self): """List of low resolution spatial slices for second spatial dimension considering padding on all sides of the spatial raster.""" ind = slice(0, self.grid_shape[1]) - slices = get_chunk_slices( - self.grid_shape[1], self.chunk_shape[1], index_slice=ind - ) + slices = get_chunk_slices(self.grid_shape[1], + self.chunk_shape[1], + index_slice=ind) return slices @property @@ -725,8 +722,7 @@ def __init__( raster_file = self._input_handler_kwargs.get('raster_file', None) raster_index = self._input_handler_kwargs.get('raster_index', None) temporal_slice = self._input_handler_kwargs.get( - 'temporal_slice', slice(None, None, 1) - ) + 'temporal_slice', slice(None, None, 1)) InputMixIn.__init__( self, target=target, @@ -757,11 +753,9 @@ def __init__( self._handle_features = None self._single_ts_files = self._input_handler_kwargs.get( - 'single_ts_files', None - ) + 'single_ts_files', None) self.cache_pattern = self._input_handler_kwargs.get( - 'cache_pattern', None - ) + 'cache_pattern', None) self.max_workers = self.worker_kwargs.get('max_workers', None) self.output_workers = self.worker_kwargs.get('output_workers', None) self.pass_workers = self.worker_kwargs.get('pass_workers', None) @@ -774,11 +768,9 @@ def __init__( self.model_kwargs = {'model_dir': self.model_kwargs} if model_class is None: - msg = ( - 'Could not load requested model class "{}" from ' - 'sup3r.models, Make sure you typed in the model class ' - 'name correctly.'.format(self.model_class) - ) + msg = ('Could not load requested model class "{}" 
from ' + 'sup3r.models, Make sure you typed in the model class ' + 'name correctly.'.format(self.model_class)) logger.error(msg) raise KeyError(msg) @@ -813,29 +805,23 @@ def __init__( def preflight(self): """Prelight path name formatting and sanity checks""" - logger.info( - 'Initializing ForwardPassStrategy. ' - f'Using n_nodes={self.nodes} with ' - f'n_spatial_chunks={self.fwp_slicer.n_spatial_chunks}, ' - f'n_temporal_chunks={self.fwp_slicer.n_temporal_chunks}, ' - f'and n_total_chunks={self.chunks}. ' - f'{self.chunks / self.nodes} chunks per node on average.' - ) - logger.info( - f'Using max_workers={self.max_workers}, ' - f'pass_workers={self.pass_workers}, ' - f'output_workers={self.output_workers}' - ) + logger.info('Initializing ForwardPassStrategy. ' + f'Using n_nodes={self.nodes} with ' + f'n_spatial_chunks={self.fwp_slicer.n_spatial_chunks}, ' + f'n_temporal_chunks={self.fwp_slicer.n_temporal_chunks}, ' + f'and n_total_chunks={self.chunks}. ' + f'{self.chunks / self.nodes} chunks per node on average.') + logger.info(f'Using max_workers={self.max_workers}, ' + f'pass_workers={self.pass_workers}, ' + f'output_workers={self.output_workers}') out = self.fwp_slicer.get_temporal_slices() self.ti_slices, self.ti_pad_slices = out - msg = ( - 'Using a padded chunk size ' - f'({self.fwp_chunk_shape[2] + 2 * self.temporal_pad}) ' - f'larger than the full temporal domain ({self.raw_tsteps}). ' - 'Should just run without temporal chunking. ' - ) + msg = ('Using a padded chunk size ' + f'({self.fwp_chunk_shape[2] + 2 * self.temporal_pad}) ' + f'larger than the full temporal domain ({self.raw_tsteps}). ' + 'Should just run without temporal chunking. ') if self.fwp_chunk_shape[2] + 2 * self.temporal_pad >= self.raw_tsteps: logger.warning(msg) warnings.warn(msg) @@ -890,8 +876,7 @@ def handle_features(self): self._handle_features = self.init_handler.handle_features else: hf = self.input_handler_class.get_handle_features( - self.file_paths - ) + self.file_paths) self._handle_features = hf return self._handle_features @@ -902,8 +887,7 @@ def hr_lat_lon(self): logger.info('Getting high-resolution grid for full output domain.') lr_lat_lon = self.lr_lat_lon.copy() self._hr_lat_lon = OutputHandler.get_lat_lon( - lr_lat_lon, self.gids.shape - ) + lr_lat_lon, self.gids.shape) return self._hr_lat_lon def get_full_domain(self, file_paths): @@ -912,9 +896,9 @@ def get_full_domain(self, file_paths): def get_lat_lon(self, file_paths, raster_index, invert_lat=False): """Get lat/lon grid for requested target and shape""" - return self.input_handler_class.get_lat_lon( - file_paths, raster_index, invert_lat=invert_lat - ) + return self.input_handler_class.get_lat_lon(file_paths, + raster_index, + invert_lat=invert_lat) def get_time_index(self, file_paths, max_workers=None, **kwargs): """Get time index for source data using DataHandler.get_time_index @@ -938,9 +922,9 @@ def get_time_index(self, file_paths, max_workers=None, **kwargs): time_index : ndarray Array of time indices for source data """ - return self.input_handler_class.get_time_index( - file_paths, max_workers=max_workers, **kwargs - ) + return self.input_handler_class.get_time_index(file_paths, + max_workers=max_workers, + **kwargs) @property def file_ids(self): @@ -971,8 +955,7 @@ def out_files(self): """ if self._out_files is None: self._out_files = self.get_output_file_names( - out_files=self.out_pattern, file_ids=self.file_ids - ) + out_files=self.out_pattern, file_ids=self.file_ids) return self._out_files @property @@ -1008,8 +991,7 @@ def 
input_handler_class(self): """ if self._input_handler_class is None: self._input_handler_class = get_input_handler_class( - self.file_paths, self._input_handler_name - ) + self.file_paths, self._input_handler_name) return self._input_handler_class @property @@ -1017,11 +999,8 @@ def max_nodes(self): """Get the maximum number of nodes that this strategy should distribute work to, equal to either the specified max number of nodes or total number of temporal chunks""" - self._max_nodes = ( - self._max_nodes - if self._max_nodes is not None - else self.fwp_slicer.n_temporal_chunks - ) + self._max_nodes = (self._max_nodes if self._max_nodes is not None else + self.fwp_slicer.n_temporal_chunks) return self._max_nodes @staticmethod @@ -1090,29 +1069,23 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.node_index = node_index self.output_data = None - msg = ( - f'Requested forward pass on chunk_index={chunk_index} > ' - f'n_chunks={strategy.chunks}' - ) + msg = (f'Requested forward pass on chunk_index={chunk_index} > ' + f'n_chunks={strategy.chunks}') assert chunk_index <= strategy.chunks, msg - logger.info( - f'Initializing ForwardPass for chunk={chunk_index} ' - f'(temporal_chunk={self.temporal_chunk_index}, ' - f'spatial_chunk={self.spatial_chunk_index}). {self.chunks}' - f' total chunks for the current node.' - ) + logger.info(f'Initializing ForwardPass for chunk={chunk_index} ' + f'(temporal_chunk={self.temporal_chunk_index}, ' + f'spatial_chunk={self.spatial_chunk_index}). {self.chunks}' + f' total chunks for the current node.') self.model_kwargs = self.strategy.model_kwargs self.model_class = self.strategy.model_class model_class = getattr(sup3r.models, self.model_class, None) if model_class is None: - msg = ( - 'Could not load requested model class "{}" from ' - 'sup3r.models, Make sure you typed in the model class ' - 'name correctly.'.format(self.model_class) - ) + msg = ('Could not load requested model class "{}" from ' + 'sup3r.models, Make sure you typed in the model class ' + 'name correctly.'.format(self.model_class)) logger.error(msg) raise KeyError(msg) @@ -1141,9 +1114,7 @@ def __init__(self, strategy, chunk_index=0, node_index=0): ] logger.info( 'Got exogenous_data of length {} with shapes: {}'.format( - len(self.exogenous_data), shapes - ) - ) + len(self.exogenous_data), shapes)) self.input_handler_class = strategy.input_handler_class @@ -1160,17 +1131,14 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.input_data = self.data_handler.data self.input_data = self.bias_correct_source_data( - self.input_data, self.strategy.lr_lat_lon - ) + self.input_data, self.strategy.lr_lat_lon) exo_s_en = self.exo_kwargs.get('s_enhancements', None) - out = self.pad_source_data( - self.input_data, self.pad_width, self.exogenous_data, exo_s_en - ) + out = self.pad_source_data(self.input_data, self.pad_width, + self.exogenous_data, exo_s_en) self.input_data, self.exogenous_data = out - self.unpadded_input_data = self.data_handler.data[ - self.lr_slice[0], self.lr_slice[1] - ] + self.unpadded_input_data = self.data_handler.data[self.lr_slice[0], + self.lr_slice[1]] def update_input_handler_kwargs(self, strategy): """Update the kwargs for the input handler for the current forward pass @@ -1224,8 +1192,7 @@ def ti_crop_slice(self): """Get low-resolution time index crop slice to crop input data time index before getting high-resolution time index""" return self.strategy.fwp_slicer.t_lr_crop_slices[ - self.temporal_chunk_index - ] + self.temporal_chunk_index] 
@property def lr_times(self): @@ -1247,8 +1214,7 @@ def hr_lat_lon(self): def hr_times(self): """Get high resolution times for the current chunk""" return self.output_handler_class.get_times( - self.lr_times, self.t_enhance * len(self.lr_times) - ) + self.lr_times, self.t_enhance * len(self.lr_times)) @property def chunk_specific_meta(self): @@ -1378,8 +1344,7 @@ def hr_slice(self): def hr_crop_slice(self): """Get hr cropping slice for the current chunk""" hr_crop_slices = self.strategy.fwp_slicer.hr_crop_slices[ - self.temporal_chunk_index - ] + self.temporal_chunk_index] return hr_crop_slices[self.spatial_chunk_index] @property @@ -1404,18 +1369,14 @@ def cache_pattern(self): if cache_pattern is not None: if '{temporal_chunk_index}' not in cache_pattern: cache_pattern = cache_pattern.replace( - '.pkl', '_{temporal_chunk_index}.pkl' - ) + '.pkl', '_{temporal_chunk_index}.pkl') if '{spatial_chunk_index}' not in cache_pattern: cache_pattern = cache_pattern.replace( - '.pkl', '_{spatial_chunk_index}.pkl' - ) + '.pkl', '_{spatial_chunk_index}.pkl') cache_pattern = cache_pattern.replace( - '{temporal_chunk_index}', str(self.temporal_chunk_index) - ) + '{temporal_chunk_index}', str(self.temporal_chunk_index)) cache_pattern = cache_pattern.replace( - '{spatial_chunk_index}', str(self.spatial_chunk_index) - ) + '{spatial_chunk_index}', str(self.spatial_chunk_index)) return cache_pattern @property @@ -1425,11 +1386,9 @@ def raster_file(self): if raster_file is not None: if '{spatial_chunk_index}' not in raster_file: raster_file = raster_file.replace( - '.txt', '_{spatial_chunk_index}.txt' - ) - raster_file = raster_file.replace( - '{spatial_chunk_index}', str(self.spatial_chunk_index) - ) + '.txt', '_{spatial_chunk_index}.txt') + raster_file = raster_file.replace('{spatial_chunk_index}', + str(self.spatial_chunk_index)) return raster_file @property @@ -1446,60 +1405,35 @@ def pad_width(self): ti_start = self.ti_slice.start or 0 ti_stop = self.ti_slice.stop or self.strategy.raw_tsteps pad_t_start = int( - np.maximum(0, (self.strategy.temporal_pad - ti_start)) - ) - pad_t_end = int( - np.maximum( - 0, - ( - self.strategy.temporal_pad - + ti_stop - - self.strategy.raw_tsteps - ), - ) - ) + np.maximum(0, (self.strategy.temporal_pad - ti_start))) + pad_t_end = (self.strategy.temporal_pad + ti_stop + - self.strategy.raw_tsteps) + pad_t_end = int(np.maximum(0, pad_t_end)) s1_start = self.lr_slice[0].start or 0 s1_stop = self.lr_slice[0].stop or self.strategy.grid_shape[0] pad_s1_start = int( - np.maximum(0, (self.strategy.spatial_pad - s1_start)) - ) - pad_s1_end = int( - np.maximum( - 0, - ( - self.strategy.spatial_pad - + s1_stop - - self.strategy.grid_shape[0] - ), - ) - ) + np.maximum(0, (self.strategy.spatial_pad - s1_start))) + pad_s1_end = (self.strategy.spatial_pad + s1_stop + - self.strategy.grid_shape[0]) + pad_s1_end = int(np.maximum(0, pad_s1_end)) s2_start = self.lr_slice[1].start or 0 s2_stop = self.lr_slice[1].stop or self.strategy.grid_shape[1] pad_s2_start = int( - np.maximum(0, (self.strategy.spatial_pad - s2_start)) - ) - pad_s2_end = int( - np.maximum( - 0, - ( - self.strategy.spatial_pad - + s2_stop - - self.strategy.grid_shape[1] - ), - ) - ) - return ( - (pad_s1_start, pad_s1_end), - (pad_s2_start, pad_s2_end), - (pad_t_start, pad_t_end), - ) + np.maximum(0, (self.strategy.spatial_pad - s2_start))) + pad_s2_end = (self.strategy.spatial_pad + s2_stop + - self.strategy.grid_shape[1]) + pad_s2_end = int(np.maximum(0, pad_s2_end)) + return ((pad_s1_start, pad_s1_end), (pad_s2_start, 
pad_s2_end), + (pad_t_start, pad_t_end)) @staticmethod - def pad_source_data( - input_data, pad_width, exo_data, exo_s_enhancements, mode='reflect' - ): + def pad_source_data(input_data, + pad_width, + exo_data, + exo_s_enhancements, + mode='reflect'): """Pad the edges of the source data from the data handler. Parameters @@ -1537,32 +1471,24 @@ def pad_source_data( """ out = np.pad(input_data, (*pad_width, (0, 0)), mode=mode) - logger.info( - 'Padded input data shape from {} to {} using mode "{}" ' - 'with padding argument: {}'.format( - input_data.shape, out.shape, mode, pad_width - ) - ) + logger.info('Padded input data shape from {} to {} using mode "{}" ' + 'with padding argument: {}'.format(input_data.shape, + out.shape, mode, + pad_width)) if exo_data is not None: for i, i_exo_data in enumerate(exo_data): if i_exo_data is not None: - total_s_enhance = exo_s_enhancements[: i + 1] + total_s_enhance = exo_s_enhancements[:i + 1] total_s_enhance = [ s for s in total_s_enhance if s is not None ] total_s_enhance = np.product(total_s_enhance) - exo_pad_width = ( - ( - total_s_enhance * pad_width[0][0], - total_s_enhance * pad_width[0][1], - ), - ( - total_s_enhance * pad_width[1][0], - total_s_enhance * pad_width[1][1], - ), - (0, 0), - ) + exo_pad_width = ((total_s_enhance * pad_width[0][0], + total_s_enhance * pad_width[0][1]), + (total_s_enhance * pad_width[1][0], + total_s_enhance * pad_width[1][1]), (0, + 0)) exo_data[i] = np.pad(i_exo_data, exo_pad_width, mode=mode) return out, exo_data @@ -1600,16 +1526,12 @@ def bias_correct_source_data(self, data, lat_lon): if 'time_index' in signature(method).parameters: feature_kwargs['time_index'] = self.data_handler.time_index - logger.debug( - 'Bias correcting feature "{}" at axis index {} ' - 'using function: {} with kwargs: {}'.format( - feature, idf, method, feature_kwargs - ) - ) + logger.debug('Bias correcting feature "{}" at axis index {} ' + 'using function: {} with kwargs: {}'.format( + feature, idf, method, feature_kwargs)) - data[..., idf] = method( - data[..., idf], lat_lon, **feature_kwargs - ) + data[..., idf] = method(data[..., idf], lat_lon, + **feature_kwargs) return data @@ -1642,13 +1564,10 @@ def _prep_exogenous_input(self, chunk_shape): chunk_shape[2], arr.shape[-1], ) - msg = ( - 'Target shape for exogenous data in forward pass ' - 'chunk was {}, but something went wrong and i ' - 'resized original data shape from {} to {}'.format( - target_shape, og_shape, arr.shape - ) - ) + msg = ('Target shape for exogenous data in forward pass ' + 'chunk was {}, but something went wrong and i ' + 'resized original data shape from {} to {}'.format( + target_shape, og_shape, arr.shape)) assert arr.shape == target_shape, msg exo_data.append(arr) @@ -1724,37 +1643,26 @@ def _run_generator( hi_res = model.generate(data_chunk, exogenous_data=exo_data) except Exception as e: msg = 'Forward pass failed on chunk with shape {}.'.format( - data_chunk.shape - ) + data_chunk.shape) logger.exception(msg) raise RuntimeError(msg) from e if len(hi_res.shape) == 4: hi_res = np.expand_dims(np.transpose(hi_res, (1, 2, 0, 3)), axis=0) - if ( - s_enhance is not None - and hi_res.shape[1] != s_enhance * data_chunk.shape[i_lr_s] - ): - msg = ( - 'The stated spatial enhancement of {}x did not match ' - 'the low res / high res shapes of {} -> {}'.format( - s_enhance, data_chunk.shape, hi_res.shape - ) - ) + if (s_enhance is not None + and hi_res.shape[1] != s_enhance * data_chunk.shape[i_lr_s]): + msg = ('The stated spatial enhancement of {}x did not match ' + 
'the low res / high res shapes of {} -> {}'.format( + s_enhance, data_chunk.shape, hi_res.shape)) logger.error(msg) raise RuntimeError(msg) - if ( - t_enhance is not None - and hi_res.shape[3] != t_enhance * data_chunk.shape[i_lr_t] - ): - msg = ( - 'The stated temporal enhancement of {}x did not match ' - 'the low res / high res shapes of {} -> {}'.format( - t_enhance, data_chunk.shape, hi_res.shape - ) - ) + if (t_enhance is not None + and hi_res.shape[3] != t_enhance * data_chunk.shape[i_lr_t]): + msg = ('The stated temporal enhancement of {}x did not match ' + 'the low res / high res shapes of {} -> {}'.format( + t_enhance, data_chunk.shape, hi_res.shape)) logger.error(msg) raise RuntimeError(msg) @@ -1836,10 +1744,8 @@ def get_node_cmd(cls, config): import_str += 'import time;\n' import_str += 'from sup3r.pipeline import Status;\n' import_str += 'from rex import init_logger;\n' - import_str += ( - 'from sup3r.pipeline.forward_pass ' - f'import ForwardPassStrategy, {cls.__name__};\n' - ) + import_str += ('from sup3r.pipeline.forward_pass ' + f'import ForwardPassStrategy, {cls.__name__};\n') fwps_init_str = get_fun_call_str(ForwardPassStrategy, config) @@ -1850,14 +1756,12 @@ def get_node_cmd(cls, config): if log_file is not None: log_arg_str += f', log_file="{log_file}"' - cmd = ( - f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"strategy = {fwps_init_str};\n" - f"{cls.__name__}.run(strategy, {node_index});\n" - "t_elap = time.time() - t0;\n" - ) + cmd = (f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"strategy = {fwps_init_str};\n" + f"{cls.__name__}.run(strategy, {node_index});\n" + "t_elap = time.time() - t0;\n") cmd = BaseCLI.add_status_cmd(config, ModuleName.FORWARD_PASS, cmd) cmd += ";\'\n" @@ -1907,10 +1811,8 @@ def _single_proc_run(cls, strategy, node_index, chunk_index): returns an initialized forward pass object, otherwise returns None """ fwp = None - check = ( - not strategy.chunk_finished(chunk_index) - and not strategy.failed_chunks - ) + check = (not strategy.chunk_finished(chunk_index) + and not strategy.failed_chunks) if strategy.failed_chunks: msg = 'A forward pass has failed. Aborting all jobs.' @@ -1959,9 +1861,8 @@ def _run_serial(cls, strategy, node_index): """ start = dt.now() - logger.debug( - f'Running forward passes on node {node_index} in ' 'serial.' - ) + logger.debug(f'Running forward passes on node {node_index} in ' + 'serial.') for i, chunk_index in enumerate(strategy.node_chunks[node_index]): now = dt.now() cls._single_proc_run( @@ -1970,20 +1871,16 @@ def _run_serial(cls, strategy, node_index): chunk_index=chunk_index, ) mem = psutil.virtual_memory() - logger.info( - 'Finished forward pass on chunk_index=' - f'{chunk_index} in {dt.now() - now}. {i + 1} of ' - f'{len(strategy.node_chunks[node_index])} ' - 'complete. Current memory usage is ' - f'{mem.used / 1e9:.3f} GB out of ' - f'{mem.total / 1e9:.3f} GB total.' - ) + logger.info('Finished forward pass on chunk_index=' + f'{chunk_index} in {dt.now() - now}. {i + 1} of ' + f'{len(strategy.node_chunks[node_index])} ' + 'complete. 
Current memory usage is ' + f'{mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.') - logger.info( - 'Finished forward passes on ' - f'{len(strategy.node_chunks[node_index])} chunks in ' - f'{dt.now() - start}' - ) + logger.info('Finished forward passes on ' + f'{len(strategy.node_chunks[node_index])} chunks in ' + f'{dt.now() - start}') @classmethod def _run_parallel(cls, strategy, node_index): @@ -2000,10 +1897,8 @@ def _run_parallel(cls, strategy, node_index): will be run. """ - logger.info( - f'Running parallel forward passes on node {node_index}' - f' with pass_workers={strategy.pass_workers}.' - ) + logger.info(f'Running parallel forward passes on node {node_index}' + f' with pass_workers={strategy.pass_workers}.') futures = {} start = dt.now() @@ -2022,48 +1917,38 @@ def _run_parallel(cls, strategy, node_index): 'start_time': dt.now(), } - logger.info( - f'Started {len(futures)} forward pass runs in ' - f'{dt.now() - now}.' - ) + logger.info(f'Started {len(futures)} forward pass runs in ' + f'{dt.now() - now}.') for i, future in enumerate(as_completed(futures)): try: future.result() mem = psutil.virtual_memory() - msg = ( - 'Finished forward pass on chunk_index=' - f'{futures[future]["chunk_index"]} in ' - f'{dt.now() - futures[future]["start_time"]}. ' - f'{i + 1} of {len(futures)} complete. ' - f'Current memory usage is {mem.used / 1e9:.3f} GB ' - f'out of {mem.total / 1e9:.3f} GB total.' - ) + msg = ('Finished forward pass on chunk_index=' + f'{futures[future]["chunk_index"]} in ' + f'{dt.now() - futures[future]["start_time"]}. ' + f'{i + 1} of {len(futures)} complete. ' + f'Current memory usage is {mem.used / 1e9:.3f} GB ' + f'out of {mem.total / 1e9:.3f} GB total.') logger.info(msg) except Exception as e: - msg = ( - 'Error running forward pass on chunk_index=' - f'{futures[future]["chunk_index"]}.' - ) + msg = ('Error running forward pass on chunk_index=' + f'{futures[future]["chunk_index"]}.') logger.exception(msg) raise RuntimeError(msg) from e - logger.info( - 'Finished asynchronous forward passes on ' - f'{len(strategy.node_chunks[node_index])} chunks in ' - f'{dt.now() - start}' - ) + logger.info('Finished asynchronous forward passes on ' + f'{len(strategy.node_chunks[node_index])} chunks in ' + f'{dt.now() - start}') def run_chunk(self): """Run a forward pass on single spatiotemporal chunk.""" - msg = ( - f'Running forward pass for chunk_index={self.chunk_index}, ' - f'node_index={self.node_index}, file_paths={self.file_paths}. ' - f'Starting forward pass on chunk_shape={self.chunk_shape} with ' - f'spatial_pad={self.strategy.spatial_pad} and temporal_pad=' - f'{self.strategy.temporal_pad}.' - ) + msg = (f'Running forward pass for chunk_index={self.chunk_index}, ' + f'node_index={self.node_index}, file_paths={self.file_paths}. 
' + f'Starting forward pass on chunk_shape={self.chunk_shape} with ' + f'spatial_pad={self.strategy.spatial_pad} and temporal_pad=' + f'{self.strategy.temporal_pad}.') logger.info(msg) data_chunk = self.input_data diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index 514a46f15..bb5c15041 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -77,33 +77,31 @@ class DataHandler(FeatureHandler, InputMixIn, TrainingPrepMixIn): 'topography', ) - def __init__( - self, - file_paths, - features, - target=None, - shape=None, - max_delta=20, - temporal_slice=slice(None, None, 1), - hr_spatial_coarsen=None, - time_roll=0, - val_split=0.0, - sample_shape=(10, 10, 1), - raster_file=None, - raster_index=None, - shuffle_time=False, - time_chunk_size=None, - cache_pattern=None, - overwrite_cache=False, - overwrite_ti_cache=False, - load_cached=False, - train_only_features=None, - handle_features=None, - single_ts_files=None, - mask_nan=False, - worker_kwargs=None, - res_kwargs=None, - ): + def __init__(self, + file_paths, + features, + target=None, + shape=None, + max_delta=20, + temporal_slice=slice(None, None, 1), + hr_spatial_coarsen=None, + time_roll=0, + val_split=0.0, + sample_shape=(10, 10, 1), + raster_file=None, + raster_index=None, + shuffle_time=False, + time_chunk_size=None, + cache_pattern=None, + overwrite_cache=False, + overwrite_ti_cache=False, + load_cached=False, + train_only_features=None, + handle_features=None, + single_ts_files=None, + mask_nan=False, + worker_kwargs=None, + res_kwargs=None): """ Parameters ---------- @@ -224,14 +222,12 @@ def __init__( 'chunks': {'south_north': 120, 'west_east': 120}} which then gets passed to xr.open_mfdataset(file, **res_kwargs) """ - InputMixIn.__init__( - self, - target=target, - shape=shape, - raster_file=raster_file, - raster_index=raster_index, - temporal_slice=temporal_slice, - ) + InputMixIn.__init__(self, + target=target, + shape=shape, + raster_file=raster_file, + raster_index=raster_index, + temporal_slice=temporal_slice) self.file_paths = file_paths self.features = ( @@ -269,27 +265,22 @@ def __init__( self._norm_workers = self.worker_kwargs.get('norm_workers', None) self._load_workers = self.worker_kwargs.get('load_workers', None) self._compute_workers = self.worker_kwargs.get('compute_workers', None) - self._worker_attrs = [ - '_ti_workers', - '_norm_workers', - '_compute_workers', - '_extract_workers', - '_load_workers', - ] + self._worker_attrs = ['_ti_workers', + '_norm_workers', + '_compute_workers', + '_extract_workers', + '_load_workers'] self.preflight() - overwrite = ( - self.overwrite_cache - and self.cache_files is not None - and all(os.path.exists(fp) for fp in self.cache_files) - ) + overwrite = (self.overwrite_cache + and self.cache_files is not None + and all(os.path.exists(fp) for fp in self.cache_files)) if self.try_load and self.load_cached: logger.info( f'All {self.cache_files} exist. Loading from cache ' - f'instead of extracting from source files.' - ) + f'instead of extracting from source files.') self.load_cached_data() elif self.try_load and not self.load_cached: @@ -297,14 +288,12 @@ def __init__( logger.info( f'All {self.cache_files} exist. Call ' 'load_cached_data() or use load_cache=True to load ' - 'this data from cache files.' - ) + 'this data from cache files.') else: if overwrite: logger.info( f'{self.cache_files} exists but overwrite_cache ' - 'is set to True. Proceeding with extraction.' 
- ) + 'is set to True. Proceeding with extraction.') self._raster_size_check() self._run_data_init_if_needed() @@ -331,8 +320,7 @@ def __init__( def try_load(self): """Check if we should try to load cache""" return self._should_load_cache( - self._cache_pattern, self.cache_files, self.overwrite_cache - ) + self._cache_pattern, self.cache_files, self.overwrite_cache) def check_clear_data(self): """Check if data is cached and clear data if not load_cached""" @@ -356,13 +344,11 @@ def _raster_size_check(self): size""" bad_shape = ( self.sample_shape[0] > self.grid_shape[0] - and self.sample_shape[1] > self.grid_shape[1] - ) + and self.sample_shape[1] > self.grid_shape[1]) if bad_shape: msg = ( f'spatial_sample_shape {self.sample_shape[:2]} is ' - f'larger than the raster size {self.grid_shape}' - ) + f'larger than the raster size {self.grid_shape}') logger.warning(msg) warnings.warn(msg) @@ -377,14 +363,11 @@ def _val_split_check(self): msg = ( f'Validation data has shape={self.val_data.shape} ' f'and sample_shape={self.sample_shape}. Use a smaller ' - 'sample_shape and/or larger val_split.' - ) + 'sample_shape and/or larger val_split.') check = any( val_size < samp_size for val_size, samp_size in zip( - self.val_data.shape, self.sample_shape - ) - ) + self.val_data.shape, self.sample_shape)) if check: logger.warning(msg) warnings.warn(msg) @@ -439,8 +422,7 @@ def extract_workers(self): n_procs = len(self.time_chunks) * len(self.extract_features) n_procs = int(np.ceil(n_procs)) extract_workers = estimate_max_workers( - self._extract_workers, proc_mem, n_procs - ) + self._extract_workers, proc_mem, n_procs) return extract_workers @property @@ -450,16 +432,13 @@ def compute_workers(self): proc_mem = int( np.ceil( len(self.extract_features) - / np.maximum(len(self.derive_features), 1) - ) - ) + / np.maximum(len(self.derive_features), 1))) proc_mem *= 4 * self.grid_mem * len(self.time_index) proc_mem /= len(self.time_chunks) n_procs = len(self.time_chunks) * len(self.derive_features) n_procs = int(np.ceil(n_procs)) compute_workers = estimate_max_workers( - self._compute_workers, proc_mem, n_procs - ) + self._compute_workers, proc_mem, n_procs) return compute_workers @property @@ -470,9 +449,8 @@ def load_workers(self): n_procs = 1 if self.cache_files is not None: n_procs = len(self.cache_files) - load_workers = estimate_max_workers( - self._load_workers, proc_mem, n_procs - ) + load_workers = estimate_max_workers(self._load_workers, proc_mem, + n_procs) return load_workers @property @@ -480,8 +458,7 @@ def norm_workers(self): """Get upper bound on workers used for normalization.""" if self.data is not None: norm_workers = estimate_max_workers( - self._norm_workers, 2 * self.feature_mem, self.shape[-1] - ) + self._norm_workers, 2 * self.feature_mem, self.shape[-1]) else: norm_workers = self._norm_workers return norm_workers @@ -530,12 +507,10 @@ def time_chunk_size(self): self._time_chunk_size = self.n_tsteps else: self._time_chunk_size = np.min( - [int(1e9 / step_mem), self.n_tsteps] - ) + [int(1e9 / step_mem), self.n_tsteps]) logger.info( 'time_chunk_size arg not specified. Using ' - f'{self._time_chunk_size}.' 
- ) + f'{self._time_chunk_size}.') return self._time_chunk_size @property @@ -601,23 +576,17 @@ def noncached_features(self): def extract_features(self): """Features to extract directly from the source handler""" lower_features = [f.lower() for f in self.handle_features] - return [ - f - for f in self.raw_features - if self.lookup(f, 'compute') is None - or Feature.get_basename(f.lower()) in lower_features - ] + return [f for f in self.raw_features + if self.lookup(f, 'compute') is None + or Feature.get_basename(f.lower()) in lower_features] @property def derive_features(self): """List of features which need to be derived from other features""" derive_features = [ - f - for f in set( - list(self.noncached_features) + list(self.extract_features) - ) - if f not in self.extract_features - ] + f for f in set( + list(self.noncached_features) + list(self.extract_features)) + if f not in self.extract_features] return derive_features @property @@ -631,8 +600,7 @@ def raw_features(self): """Get list of features needed for computations""" if self._raw_features is None: self._raw_features = self.get_raw_feature_list( - self.noncached_features, self.handle_features - ) + self.noncached_features, self.handle_features) return self._raw_features @property @@ -643,8 +611,7 @@ def output_features(self): for feature in self.features: ignore = any( fnmatch(feature.lower(), pattern.lower()) - for pattern in self.train_only_features - ) + for pattern in self.train_only_features) if not ignore: out.append(feature) return out @@ -683,9 +650,7 @@ def preflight(self): if len(self.sample_shape) == 2: logger.info( 'Found 2D sample shape of {}. Adding temporal dim of 1'.format( - self.sample_shape - ) - ) + self.sample_shape)) self.sample_shape = (*self.sample_shape, 1) start = self.temporal_slice.start @@ -693,8 +658,7 @@ def preflight(self): n_steps = self.n_tsteps msg = ( f'Temporal slice step ({self.temporal_slice.step}) does not ' - f'evenly divide the number of time steps ({n_steps})' - ) + f'evenly divide the number of time steps ({n_steps})') check = self.temporal_slice.step is None check = check or n_steps % self.temporal_slice.step == 0 if not check: @@ -704,8 +668,7 @@ def preflight(self): msg = ( f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger ' 'than the number of time steps in the raw data ' - f'({len(self.raw_time_index)}).' - ) + f'({len(self.raw_time_index)}).') if len(self.raw_time_index) < self.sample_shape[2]: logger.warning(msg) warnings.warn(msg) @@ -713,15 +676,13 @@ def preflight(self): msg = ( f'The requested time slice {self.temporal_slice} conflicts ' f'with the number of time steps ({len(self.raw_time_index)}) ' - 'in the raw data' - ) + 'in the raw data') t_slice_is_subset = start is not None and stop is not None good_subset = ( t_slice_is_subset and (stop - start <= len(self.raw_time_index)) and stop <= len(self.raw_time_index) - and start <= len(self.raw_time_index) - ) + and start <= len(self.raw_time_index)) if t_slice_is_subset and not good_subset: logger.error(msg) raise RuntimeError(msg) @@ -730,8 +691,7 @@ def preflight(self): f'Initializing DataHandler {self.input_file_info}. 
' f'Getting temporal range {self.time_index[0]!s} to ' f'{self.time_index[-1]!s} (inclusive) ' - f'based on temporal_slice {self.temporal_slice}' - ) + f'based on temporal_slice {self.temporal_slice}') logger.info(msg) logger.info( @@ -740,8 +700,7 @@ def preflight(self): f'extract_workers={self.extract_workers}, ' f'compute_workers={self.compute_workers}, ' f'load_workers={self.load_workers}, ' - f'ti_workers={self.ti_workers}' - ) + f'ti_workers={self.ti_workers}') @classmethod def get_lat_lon(cls, file_paths, raster_index, invert_lat=False): @@ -787,8 +746,7 @@ def get_node_cmd(cls, config): f'import {cls.__name__};\n' 'import time;\n' 'from sup3r.pipeline import Status;\n' - 'from rex import init_logger;\n' - ) + 'from rex import init_logger;\n') dh_init_str = get_fun_call_str(cls, config) @@ -810,22 +768,19 @@ def get_node_cmd(cls, config): "t0 = time.time();\n" f"logger = init_logger({log_arg_str});\n" f"data_handler = {dh_init_str};\n" - "t_elap = time.time() - t0;\n" - ) + "t_elap = time.time() - t0;\n") cmd = BaseCLI.add_status_cmd(config, ModuleName.DATA_EXTRACT, cmd) cmd += ";\'\n" return cmd.replace('\\', '/') - def get_cache_file_names( - self, - cache_pattern, - grid_shape=None, - time_index=None, - target=None, - features=None, - ): + def get_cache_file_names(self, + cache_pattern, + grid_shape=None, + time_index=None, + target=None, + features=None): """Get names of cache files from cache_pattern and feature names Parameters @@ -852,8 +807,7 @@ def get_cache_file_names( features = features if features is not None else self.features return self._get_cache_file_names( - cache_pattern, grid_shape, time_index, target, features - ) + cache_pattern, grid_shape, time_index, target, features) def unnormalize(self, means, stds): """Remove normalization from stored means and stds""" @@ -873,8 +827,7 @@ def normalize(self, means, stds): """ max_workers = self.norm_workers self._normalize( - self.data, self.val_data, means, stds, max_workers=max_workers - ) + self.data, self.val_data, means, stds, max_workers=max_workers) def get_next(self): """Get data for observation using random observation index. 
Loops @@ -887,8 +840,7 @@ def get_next(self): (spatial_1, spatial_2, temporal, features) """ self.current_obs_index = self._get_observation_index( - self.data, self.sample_shape - ) + self.data, self.sample_shape) observation = self.data[self.current_obs_index] return observation @@ -921,8 +873,7 @@ def split_data(self, data=None, val_split=0.0, shuffle_time=False): assert len(self.time_index) == self.data.shape[-2] train_indices, val_indices = self._split_data_indices( - data, val_split=val_split, shuffle_time=shuffle_time - ) + data, val_split=val_split, shuffle_time=shuffle_time) self.val_data = self.data[:, :, val_indices, :] self.data = self.data[:, :, train_indices, :] @@ -963,8 +914,7 @@ def requested_shape(self): shape[0] // self.hr_spatial_coarsen, shape[1] // self.hr_spatial_coarsen, len(self.raw_time_index[self.temporal_slice]), - len(self.features), - ) + len(self.features)) return requested_shape def load_cached_data(self, with_split=True): @@ -985,23 +935,18 @@ def load_cached_data(self, with_split=True): len(self.cache_files), len(self.features), self.features, - self.cache_files, - ) - ) + self.cache_files)) assert len(self.cache_files) == len(self.features), msg - self.data = np.full( - shape=self.requested_shape, fill_value=np.nan, dtype=np.float32 - ) + self.data = np.full(shape=self.requested_shape, fill_value=np.nan, + dtype=np.float32) logger.info(f'Loading cached data from: {self.cache_files}') max_workers = self.load_workers - self._load_cached_data( - data=self.data, - cache_files=self.cache_files, - features=self.features, - max_workers=max_workers, - ) + self._load_cached_data(data=self.data, + cache_files=self.cache_files, + features=self.features, + max_workers=max_workers) self.time_index = self.raw_time_index[self.temporal_slice] @@ -1014,13 +959,11 @@ def load_cached_data(self, with_split=True): logger.debug( 'Splitting data into training / validation sets ' f'({1 - self.val_split}, {self.val_split}) ' - f'for {self.input_file_info}' - ) + f'for {self.input_file_info}') if with_split: self.data, self.val_data = self.split_data( - val_split=self.val_split, shuffle_time=self.shuffle_time - ) + val_split=self.val_split, shuffle_time=self.shuffle_time) def run_all_data_init(self): """Build base 4D data array. Can handle multiple files but assumes @@ -1042,8 +985,7 @@ def run_all_data_init(self): else: n_steps = len(self.raw_time_index[self.temporal_slice]) shifted_time_chunks = get_chunk_slices( - n_steps, self.time_chunk_size - ) + n_steps, self.time_chunk_size) self.run_data_extraction() self.run_data_compute() @@ -1072,11 +1014,8 @@ def run_all_data_init(self): with open(self.cache_files[f_index], 'rb') as fh: self.data[..., f_index] = pickle.load(fh) - logger.info( - 'Finished extracting data for ' - f'{self.input_file_info} in ' - f'{dt.now() - now}' - ) + logger.info(f'Finished extracting data for {self.input_file_info} in ' + f'{dt.now() - now}') return self.data def run_data_extraction(self): @@ -1086,26 +1025,21 @@ def run_data_extraction(self): if self.extract_features: logger.info( f'Starting extraction of {self.extract_features} ' - f'using {len(self.time_chunks)} time_chunks.' 
- ) + f'using {len(self.time_chunks)} time_chunks.') if self.extract_workers == 1: - self._raw_data = self.serial_extract( - self.file_paths, - self.raster_index, - self.time_chunks, - self.extract_features, - **self.res_kwargs, - ) + self._raw_data = self.serial_extract(self.file_paths, + self.raster_index, + self.time_chunks, + self.extract_features, + **self.res_kwargs) else: - self._raw_data = self.parallel_extract( - self.file_paths, - self.raster_index, - self.time_chunks, - self.extract_features, - self.extract_workers, - **self.res_kwargs, - ) + self._raw_data = self.parallel_extract(self.file_paths, + self.raster_index, + self.time_chunks, + self.extract_features, + self.extract_workers, + **self.res_kwargs) logger.info( f'Finished extracting {self.extract_features} for ' @@ -1120,32 +1054,27 @@ def run_data_compute(self): logger.info(f'Starting computation of {self.derive_features}') if self.compute_workers == 1: - self._raw_data = self.serial_compute( - self._raw_data, - self.file_paths, - self.raster_index, - self.time_chunks, - self.derive_features, - self.noncached_features, - self.handle_features, - ) + self._raw_data = self.serial_compute(self._raw_data, + self.file_paths, + self.raster_index, + self.time_chunks, + self.derive_features, + self.noncached_features, + self.handle_features) elif self.compute_workers != 1: - self._raw_data = self.parallel_compute( - self._raw_data, - self.file_paths, - self.raster_index, - self.time_chunks, - self.derive_features, - self.noncached_features, - self.handle_features, - self.compute_workers, - ) + self._raw_data = self.parallel_compute(self._raw_data, + self.file_paths, + self.raster_index, + self.time_chunks, + self.derive_features, + self.noncached_features, + self.handle_features, + self.compute_workers) logger.info( f'Finished computing {self.derive_features} for ' - f'{self.input_file_info}' - ) + f'{self.input_file_info}') def data_fill(self, t, t_slice, f_index, f): """Place single extracted / computed chunk in final data array @@ -1183,8 +1112,7 @@ def serial_data_fill(self, shifted_time_chunks): if t % interval == 0: logger.info( f'Added {t + 1} of {len(shifted_time_chunks)} ' - 'chunks to final data array' - ) + 'chunks to final data array') self._raw_data.pop(t) def parallel_data_fill(self, shifted_time_chunks, max_workers=None): @@ -1200,15 +1128,11 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None): max available workers will be used. If 1 cached data will be loaded in serial """ - self.data = np.zeros( - ( - self.grid_shape[0], - self.grid_shape[1], - self.n_tsteps, - len(self.features), - ), - dtype=np.float32, - ) + self.data = np.zeros((self.grid_shape[0], + self.grid_shape[1], + self.n_tsteps, + len(self.features)), + dtype=np.float32) if max_workers == 1: self.serial_data_fill(shifted_time_chunks) @@ -1224,8 +1148,7 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None): logger.info( f'Started adding {len(futures)} chunks ' - f'to data array in {dt.now() - now}.' - ) + f'to data array in {dt.now() - now}.') interval = int(np.ceil(len(futures) / 10)) for i, future in enumerate(as_completed(futures)): @@ -1235,15 +1158,13 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None): msg = ( f'Error adding ({futures[future]["t"]}, ' f'{futures[future]["fidx"]}) chunk to ' - 'final data array.' 
- ) + 'final data array.') logger.exception(msg) raise RuntimeError(msg) from e if i % interval == 0: logger.debug( f'Added {i+1} out of {len(futures)} ' - 'chunks to final data array' - ) + 'chunks to final data array') logger.info('Finished building data array') @abstractmethod @@ -1290,8 +1211,7 @@ def lin_bc(self, bc_files, threshold=0.1): lat_lon=self.lat_lon, feature_name=feature, bias_fp=fp, - threshold=threshold, - ) + threshold=threshold) if scalar.shape[-1] == 1: scalar = np.repeat(scalar, self.shape[2], axis=2) @@ -1305,17 +1225,14 @@ def lin_bc(self, bc_files, threshold=0.1): 'Can only accept bias correction factors ' 'with last dim equal to 1 or 12 but ' 'received bias correction factors with ' - 'shape {}'.format(scalar.shape) - ) + 'shape {}'.format(scalar.shape)) logger.error(msg) raise RuntimeError(msg) logger.info( 'Bias correcting "{}" with linear ' 'correction from "{}"'.format( - feature, os.path.basename(fp) - ) - ) + feature, os.path.basename(fp))) self.data[..., idf] *= scalar self.data[..., idf] += adder completed.append(feature) @@ -1325,9 +1242,8 @@ def lin_bc(self, bc_files, threshold=0.1): class DataHandlerDC(DataHandler): """Data-centric data handler""" - def get_observation_index( - self, temporal_weights=None, spatial_weights=None - ): + def get_observation_index(self, temporal_weights=None, + spatial_weights=None): """Randomly gets weighted spatial sample and time sample Parameters @@ -1347,24 +1263,19 @@ def get_observation_index( """ if spatial_weights is not None: spatial_slice = weighted_box_sampler( - self.data, self.sample_shape[:2], weights=spatial_weights - ) + self.data, self.sample_shape[:2], weights=spatial_weights) else: spatial_slice = uniform_box_sampler( - self.data, self.sample_shape[:2] - ) + self.data, self.sample_shape[:2]) if temporal_weights is not None: temporal_slice = weighted_time_sampler( - self.data, self.sample_shape[2], weights=temporal_weights - ) + self.data, self.sample_shape[2], weights=temporal_weights) else: temporal_slice = uniform_time_sampler( - self.data, self.sample_shape[2] - ) + self.data, self.sample_shape[2]) return tuple( - [*spatial_slice, temporal_slice, np.arange(len(self.features))] - ) + [*spatial_slice, temporal_slice, np.arange(len(self.features))]) def get_next(self, temporal_weights=None, spatial_weights=None): """Get data for observation using weighted random observation index. @@ -1386,8 +1297,7 @@ def get_next(self, temporal_weights=None, spatial_weights=None): (spatial_1, spatial_2, temporal, features) """ self.current_obs_index = self.get_observation_index( - temporal_weights=temporal_weights, spatial_weights=spatial_weights - ) + temporal_weights=temporal_weights, spatial_weights=spatial_weights) observation = self.data[self.current_obs_index] return observation @@ -1423,8 +1333,7 @@ def get_file_times(cls, file_paths, **kwargs): else: msg = ( f'Could not get time_index for {file_paths}. ' - 'Assuming time independence.' 
- ) + 'Assuming time independence.') time_index = None logger.warning(msg) warnings.warn(msg) @@ -1452,11 +1361,9 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): time_index : pd.Datetimeindex List of times as a Datetimeindex """ - max_workers = ( - len(file_paths) - if max_workers is None - else np.min((max_workers, len(file_paths))) - ) + max_workers = (len(file_paths) + if max_workers is None + else np.min((max_workers, len(file_paths)))) if max_workers == 1: return cls.get_file_times(file_paths, **kwargs) ti = {} @@ -1469,8 +1376,7 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): logger.info( f'Started building time index from {len(file_paths)} ' - f'files in {dt.now() - now}.' - ) + f'files in {dt.now() - now}.') for i, future in enumerate(as_completed(futures)): try: @@ -1478,10 +1384,8 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): if val is not None: ti[futures[future]['idx']] = list(val) except Exception as e: - msg = ( - 'Error while getting time index from file ' - f'{futures[future]["file"]}.' - ) + msg = ('Error while getting time index from file ' + f'{futures[future]["file"]}.') logger.exception(msg) raise RuntimeError(msg) from e logger.debug(f'Stored {i+1} out of {len(futures)} file times') @@ -1497,34 +1401,30 @@ def feature_registry(cls): dict Method registry """ - registry = { - 'BVF2_(.*)m': BVFreqSquaredNC, - 'BVF_MO_(.*)m': BVFreqMon, - 'RMOL': InverseMonNC, - 'U_(.*)': UWind, - 'V_(.*)': VWind, - 'Windspeed_(.*)m': WindspeedNC, - 'Winddirection_(.*)m': WinddirectionNC, - 'lat_lon': LatLonNC, - 'Shear_(.*)m': Shear, - 'REWS_(.*)m': Rews, - 'Temperature_(.*)m': TempNC, - 'Pressure_(.*)m': PressureNC, - 'PotentialTemp_(.*)m': PotentialTempNC, - 'PT_(.*)m': PotentialTempNC, - 'topography': ['HGT', 'orog'], - } + registry = {'BVF2_(.*)m': BVFreqSquaredNC, + 'BVF_MO_(.*)m': BVFreqMon, + 'RMOL': InverseMonNC, + 'U_(.*)': UWind, + 'V_(.*)': VWind, + 'Windspeed_(.*)m': WindspeedNC, + 'Winddirection_(.*)m': WinddirectionNC, + 'lat_lon': LatLonNC, + 'Shear_(.*)m': Shear, + 'REWS_(.*)m': Rews, + 'Temperature_(.*)m': TempNC, + 'Pressure_(.*)m': PressureNC, + 'PotentialTemp_(.*)m': PotentialTempNC, + 'PT_(.*)m': PotentialTempNC, + 'topography': ['HGT', 'orog']} return registry @classmethod - def extract_feature( - cls, - file_paths, - raster_index, - feature, - time_slice=slice(None), - **kwargs, - ): + def extract_feature(cls, + file_paths, + raster_index, + feature, + time_slice=slice(None), + **kwargs): """Extract single feature from data source. The requested feature can match exactly to one found in the source data or can have a matching prefix with a suffix specifying the height or pressure level @@ -1566,8 +1466,7 @@ def extract_feature( if feature in handle or feature.lower() in handle: feat_key = feature if feature in handle else feature.lower() fdata = cls.direct_extract( - handle, feat_key, raster_index, time_slice - ) + handle, feat_key, raster_index, time_slice) elif basename in handle or basename.lower() in handle: feat_key = basename if basename in handle else basename.lower() @@ -1577,16 +1476,14 @@ def extract_feature( feat_key, raster_index, np.float32(interp_height), - time_slice, - ) + time_slice) elif interp_pressure is not None: fdata = Interpolator.interp_var_to_pressure( handle, feat_key, raster_index, np.float32(interp_pressure), - time_slice, - ) + time_slice) else: msg = f'{feature} cannot be extracted from source data.' 
@@ -1672,13 +1569,11 @@ def get_closest_lat_lon(lat_lon, target): """ # shape of ll2 is (n, 2) where axis=1 is (lat, lon) ll2 = np.vstack( - (lat_lon[..., 0].flatten(), lat_lon[..., 1].flatten()) - ).T + (lat_lon[..., 0].flatten(), lat_lon[..., 1].flatten())).T tree = KDTree(ll2) _, i = tree.query(np.array(target)) row, col = np.where( - (lat_lon[..., 0] == ll2[i, 0]) & (lat_lon[..., 1] == ll2[i, 1]) - ) + (lat_lon[..., 0] == ll2[i, 0]) & (lat_lon[..., 1] == ll2[i, 1])) row = row[0] col = col[0] return row, col @@ -1721,10 +1616,8 @@ def compute_raster_index(cls, file_paths, target, grid_shape): else: row_end = row + grid_shape[0] row_start = row - raster_index = [ - slice(row_start, row_end), - slice(col, col + grid_shape[1]), - ] + raster_index = [slice(row_start, row_end), + slice(col, col + grid_shape[1])] cls._validate_raster_shape(target, grid_shape, lat_lon, raster_index) return raster_index @@ -1750,18 +1643,15 @@ def _check_grid_extent(cls, target, grid_shape, lat_lon): max_lon = np.max(lat_lon[..., 1]) logger.debug( 'Calculating raster index from WRF file ' - f'for shape {grid_shape} and target {target}' - ) + f'for shape {grid_shape} and target {target}') logger.debug( - f'lat/lon (min, max): {min_lat}/{min_lon}, ' f'{max_lat}/{max_lon}' - ) + f'lat/lon (min, max): {min_lat}/{min_lon}, ' + f'{max_lat}/{max_lon}') msg = ( f'target {target} out of bounds with min lat/lon ' - f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}' - ) - assert ( - min_lat <= target[0] <= max_lat and min_lon <= target[1] <= max_lon - ), msg + f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}') + assert (min_lat <= target[0] <= max_lat + and min_lon <= target[1] <= max_lon), msg @classmethod def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): @@ -1781,20 +1671,18 @@ def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): raster_index : list List of slices selecting region from entire available grid. """ - if ( - raster_index[0].stop > lat_lon.shape[0] - or raster_index[1].stop > lat_lon.shape[1] - or raster_index[0].start < 0 - or raster_index[1].start < 0 - ): + check = (raster_index[0].stop > lat_lon.shape[0] + or raster_index[1].stop > lat_lon.shape[1] + or raster_index[0].start < 0 + or raster_index[1].start < 0) + if check: msg = ( f'Invalid target {target}, shape {grid_shape}, and raster ' f'{raster_index} for data domain of size ' f'{lat_lon.shape[:-1]} with lower left corner ' f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' f' and upper right corner ({np.max(lat_lon[..., 0])}, ' - f'{np.max(lat_lon[..., 1])}).' 
- ) + f'{np.max(lat_lon[..., 1])}).') raise ValueError(msg) def get_raster_index(self): @@ -1810,30 +1698,24 @@ def get_raster_index(self): self.raster_file = ( self.raster_file if self.raster_file is None - else self.raster_file.replace('.txt', '.npy') - ) + else self.raster_file.replace('.txt', '.npy')) if self.raster_file is not None and os.path.exists(self.raster_file): logger.debug( f'Loading raster index: {self.raster_file} ' - f'for {self.input_file_info}' - ) + f'for {self.input_file_info}') raster_index = np.load(self.raster_file, allow_pickle=True) raster_index = list(raster_index) else: check = self.grid_shape is not None and self.target is not None msg = ( 'Must provide raster file or shape + target to get ' - 'raster index' - ) + 'raster index') assert check, msg raster_index = self.compute_raster_index( - self.file_paths, self.target, self.grid_shape - ) + self.file_paths, self.target, self.grid_shape) logger.debug( 'Found raster index with row, col slices: {}'.format( - raster_index - ) - ) + raster_index)) if self.raster_file is not None: basedir = os.path.dirname(self.raster_file) diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index 205b25480..70bbd4af8 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -216,8 +216,7 @@ def _run_pair_checks(self, hr_handler, lr_handler): hr_shape = self.hr_data.shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, - hr_shape[2] // self.t_enhance, - hr_shape[3]) + hr_shape[2] // self.t_enhance, hr_shape[3]) msg = (f'hr_data.shape {self.hr_data.shape} and ' f'lr_data.shape {self.lr_data.shape} are ' f'incompatible. Must be {hr_shape} and {lr_shape}.') @@ -227,8 +226,7 @@ def _run_pair_checks(self, hr_handler, lr_handler): hr_shape = self.hr_val_data.shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, - hr_shape[2] // self.t_enhance, - hr_shape[3]) + hr_shape[2] // self.t_enhance, hr_shape[3]) msg = (f'hr_val_data.shape {self.hr_val_data.shape} ' f'and lr_val_data.shape {self.lr_val_data.shape}' f' are incompatible. 
Must be {hr_shape} and {lr_shape}.')
@@ -340,8 +338,8 @@ def lr_lat_lon(self, lat_lon):
def hr_lat_lon(self):
"""Get high_res lat lon array"""
if self._hr_lat_lon is None:
- self._hr_lat_lon = self.hr_dh.lat_lon[:self.hr_required_shape[0],
- :self.hr_required_shape[1]]
+ self._hr_lat_lon = self.hr_dh.lat_lon[:self.hr_required_shape[0], :
+ self.hr_required_shape[1]]
return self._hr_lat_lon
@hr_lat_lon.setter
@@ -440,8 +438,8 @@ def regrid_lr_data(self):
out = []
for i in range(len(self.features)):
- tmp = regridder(self.lr_input_data[..., i]).reshape(
- self.lr_required_shape)[..., np.newaixs]
+ tmp = regridder(self.lr_input_data[..., i])
+ tmp = tmp.reshape(self.lr_required_shape)[..., np.newaxis]
out.append(tmp)
return np.concatenate(out, axis=-1)
@@ -470,6 +468,8 @@ def get_next(self):
slice(s.start * self.t_enhance, s.stop * self.t_enhance))
hr_obs_idx.append(lr_obs_idx[-1])
hr_obs_idx = tuple(hr_obs_idx)
- self.current_obs_index = {'hr_index': hr_obs_idx,
- 'lr_index': lr_obs_idx}
+ self.current_obs_index = {
+ 'hr_index': hr_obs_idx,
+ 'lr_index': lr_obs_idx
+ }
return self.hr_data[hr_obs_idx], self.lr_data[lr_obs_idx]
diff --git a/sup3r/preprocessing/data_handling/nc_data_handling.py b/sup3r/preprocessing/data_handling/nc_data_handling.py
index 744039d9b..d655f4c5f 100644
--- a/sup3r/preprocessing/data_handling/nc_data_handling.py
+++ b/sup3r/preprocessing/data_handling/nc_data_handling.py
@@ -7,6 +7,7 @@
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime as dt
+from typing import ClassVar
import numpy as np
import pandas as pd
@@ -17,33 +18,17 @@
from scipy.stats import mode
from sup3r.preprocessing.data_handling.base import DataHandler, DataHandlerDC
-from sup3r.preprocessing.feature_handling import (
- BVFreqMon,
- BVFreqSquaredNC,
- ClearSkyRatioCC,
- Feature,
- InverseMonNC,
- LatLonNC,
- PotentialTempNC,
- PressureNC,
- Rews,
- Shear,
- Tas,
- TasMax,
- TasMin,
- TempNC,
- TempNCforCC,
- UWind,
- VWind,
- WinddirectionNC,
- WindspeedNC,
-)
+from sup3r.preprocessing.feature_handling import (BVFreqMon, BVFreqSquaredNC,
+ ClearSkyRatioCC, Feature,
+ InverseMonNC, LatLonNC,
+ PotentialTempNC, PressureNC,
+ Rews, Shear, Tas, TasMax,
+ TasMin, TempNC, TempNCforCC,
+ UWind, VWind,
+ WinddirectionNC, WindspeedNC)
from sup3r.utilities.interpolation import Interpolator
-from sup3r.utilities.utilities import (
- estimate_max_workers,
- get_time_dim_name,
- np_to_pd_times,
-)
+from sup3r.utilities.utilities import (estimate_max_workers, get_time_dim_name,
+ np_to_pd_times)
np.random.seed(42)
@@ -53,7 +38,7 @@ class DataHandlerNC(DataHandler):
"""Data Handler for NETCDF data"""
- CHUNKS = {
+ CHUNKS: ClassVar[dict] = {
'XTIME': 100,
'XLAT': 150,
'XLON': 150,
diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py
index 214b265f2..de0080ebf 100644
--- a/sup3r/utilities/era_downloader.py
+++ b/sup3r/utilities/era_downloader.py
@@ -39,11 +39,9 @@ class EraDownloader:
"""Class to handle ERA5 downloading, variable renaming, file combination, and interpolation."""
- msg = (
- 'To download ERA5 data you need to have a ~/.cdsapirc file '
- 'with a valid url and api key. Follow the instructions here: '
- 'https://cds.climate.copernicus.eu/api-how-to'
- )
+ msg = ('To download ERA5 data you need to have a ~/.cdsapirc file '
+ 'with a valid url and api key. Follow the instructions here: '
Follow the instructions here: ' + 'https://cds.climate.copernicus.eu/api-how-to') req_file = os.path.join(os.path.expanduser('~'), '.cdsapirc') assert os.path.exists(req_file), msg @@ -124,19 +122,17 @@ class EraDownloader: 'tp': 'total_precip', } - def __init__( - self, - year, - month, - area, - levels, - combined_out_pattern, - interp_out_pattern=None, - run_interp=True, - overwrite=False, - required_shape=None, - variables=None, - ): + def __init__(self, + year, + month, + area, + levels, + combined_out_pattern, + interp_out_pattern=None, + run_interp=True, + overwrite=False, + required_shape=None, + variables=None): """Initialize the class. Parameters @@ -187,11 +183,9 @@ def __init__( self.check_good_vars(self.variables) self.prep_var_lists(self.variables) - msg = ( - 'Initialized EraDownloader with: ' - f'year={self.year}, month={self.month}, area={self.area}, ' - f'levels={self.levels}, variables={self.variables}' - ) + msg = ('Initialized EraDownloader with: ' + f'year={self.year}, month={self.month}, area={self.area}, ' + f'levels={self.levels}, variables={self.variables}') logger.info(msg) @property @@ -206,7 +200,8 @@ def days(self): """Get list of days for the requested month""" return [ str(n).zfill(2) - for n in np.arange(1, monthrange(self.year, self.month)[1] + 1) + for n in np.arange(1, + monthrange(self.year, self.month)[1] + 1) ] @property @@ -215,8 +210,7 @@ def interp_file(self): if self._interp_file is None: if self.interp_out_pattern is not None and self.run_interp: self._interp_file = self.interp_out_pattern.format( - year=self.year, month=str(self.month).zfill(2) - ) + year=self.year, month=str(self.month).zfill(2)) os.makedirs(os.path.dirname(self._interp_file), exist_ok=True) return self._interp_file @@ -225,8 +219,7 @@ def combined_file(self): """Get name of file from combined surface and level files""" if self._combined_file is None: self._combined_file = self.combined_out_pattern.format( - year=self.year, month=str(self.month).zfill(2) - ) + year=self.year, month=str(self.month).zfill(2)) os.makedirs(os.path.dirname(self._combined_file), exist_ok=True) return self._combined_file @@ -299,10 +292,8 @@ def check_good_vars(self, variables): """ good = all(var in self.VALID_VARIABLES for var in variables) if not good: - msg = ( - f'Received variables {variables} not in valid variables ' - f'list {self.VALID_VARIABLES}' - ) + msg = (f'Received variables {variables} not in valid variables ' + f'list {self.VALID_VARIABLES}') logger.error(msg) raise OSError(msg) @@ -327,9 +318,8 @@ def prep_var_lists(self, variables): for var in variables: if var in self.SFC_VARS and var not in self.sfc_file_variables: self.sfc_file_variables.append(var) - elif ( - var in self.LEVEL_VARS and var not in self.level_file_variables - ): + elif (var in self.LEVEL_VARS + and var not in self.level_file_variables): self.level_file_variables.append(var) else: msg = f'Requested {var} is not available for download.' @@ -339,14 +329,11 @@ def prep_var_lists(self, variables): def download_process_combine(self): """Run the download routine.""" sfc_check = len(self.sfc_file_variables) > 0 - level_check = ( - len(self.level_file_variables) > 0 and self.levels is not None - ) + level_check = (len(self.level_file_variables) > 0 + and self.levels is not None) if self.level_file_variables: - msg = ( - f'{self.level_file_variables} requested but no levels' - ' were provided.' 
- ) + msg = (f'{self.level_file_variables} requested but no levels' + ' were provided.') if self.levels is None: logger.warning(msg) warn(msg) @@ -360,14 +347,11 @@ def download_process_combine(self): def download_levels_file(self): """Download file with requested pressure levels""" if not os.path.exists(self.level_file) or self.overwrite: - msg = ( - f'Downloading {self.level_file_variables} to ' - f'{self.level_file}.' - ) + msg = (f'Downloading {self.level_file_variables} to ' + f'{self.level_file}.') logger.info(msg) CDS_API_CLIENT.retrieve( - 'reanalysis-era5-pressure-levels', - { + 'reanalysis-era5-pressure-levels', { 'product_type': 'reanalysis', 'format': 'netcdf', 'variable': self.level_file_variables, @@ -377,23 +361,18 @@ def download_levels_file(self): 'day': self.days, 'time': self.hours, 'area': self.area, - }, - self.level_file, - ) + }, self.level_file) else: logger.info(f'File already exists: {self.level_file}.') def download_surface_file(self): """Download surface file""" if not os.path.exists(self.surface_file) or self.overwrite: - msg = ( - f'Downloading {self.sfc_file_variables} to ' - f'{self.surface_file}.' - ) + msg = (f'Downloading {self.sfc_file_variables} to ' + f'{self.surface_file}.') logger.info(msg) CDS_API_CLIENT.retrieve( - 'reanalysis-era5-single-levels', - { + 'reanalysis-era5-single-levels', { 'product_type': 'reanalysis', 'format': 'netcdf', 'variable': self.sfc_file_variables, @@ -402,9 +381,7 @@ def download_surface_file(self): 'day': self.days, 'time': self.hours, 'area': self.area, - }, - self.surface_file, - ) + }, self.surface_file) else: logger.info(f'File already exists: {self.surface_file}.') @@ -421,10 +398,8 @@ def process_surface_file(self): ds = self.map_vars(old_ds, ds) os.system(f'mv {tmp_file} {self.surface_file}') - logger.info( - f'Finished processing {self.surface_file}. Moved ' - f'{tmp_file} to {self.surface_file}.' - ) + logger.info(f'Finished processing {self.surface_file}. Moved ' + f'{tmp_file} to {self.surface_file}.') def map_vars(self, old_ds, ds): """Map variables from old dataset to new dataset @@ -475,9 +450,9 @@ def convert_z(self, standard_name, long_name, old_ds, ds): Dataset() object for new file with new height variable written. """ - _ = ds.createVariable( - standard_name, np.float32, dimensions=old_ds['z'].dimensions - ) + _ = ds.createVariable(standard_name, + np.float32, + dimensions=old_ds['z'].dimensions) ds.variables[standard_name][:] = old_ds['z'][:] / 9.81 ds.variables[standard_name].long_name = long_name ds.variables[standard_name].standard_name = 'zg' @@ -501,18 +476,16 @@ def process_level_file(self): tmp = np.zeros(ds.variables['zg'].shape) for i in range(tmp.shape[1]): tmp[:, i, :, :] = ds.variables['level'][i] * 100 - _ = ds.createVariable( - 'pressure', np.float32, dimensions=dims - ) + _ = ds.createVariable('pressure', + np.float32, + dimensions=dims) ds.variables['pressure'][:] = tmp[...] ds.variables['pressure'].long_name = 'Pressure' ds.variables['pressure'].units = 'Pa' os.system(f'mv {tmp_file} {self.level_file}') - logger.info( - f'Finished processing {self.level_file}. Moved ' - f'{tmp_file} to {self.level_file}.' - ) + logger.info(f'Finished processing {self.level_file}. Moved ' + f'{tmp_file} to {self.level_file}.') def process_and_combine(self): """Process variables and combine.""" @@ -528,10 +501,8 @@ def process_and_combine(self): self.process_surface_file() files.append(self.surface_file) - logger.info( - f'Combining {files} and {self.surface_file} ' - f'to {self.combined_file}.' 
- ) + logger.info(f'Combining {files} and {self.surface_file} ' + f'to {self.combined_file}.') with xr.open_mfdataset(files) as ds: ds.to_netcdf(self.combined_file) logger.info(f'Finished writing {self.combined_file}') @@ -556,12 +527,10 @@ def good_file(self, file, required_shape): bool Whether or not data has required shape and variables. """ - out = self.check_single_file( - file, - check_nans=False, - check_heights=False, - required_shape=required_shape, - ) + out = self.check_single_file(file, + check_nans=False, + check_heights=False, + required_shape=required_shape) good_vars, good_shape, _, _ = out check = good_vars and good_shape return check @@ -582,31 +551,26 @@ def check_existing_files(self): os.remove(self.level_file) if os.path.exists(self.surface_file): os.remove(self.surface_file) - logger.info( - f'{self.combined_file} already exists and ' - f'overwrite={self.overwrite}. Skipping.' - ) + logger.info(f'{self.combined_file} already exists and ' + f'overwrite={self.overwrite}. Skipping.') except Exception as e: logger.info(f'Something wrong with {self.combined_file}. {e}') if os.path.exists(self.combined_file): os.remove(self.combined_file) check = self.interp_file is not None and os.path.exists( - self.interp_file - ) + self.interp_file) if check: os.remove(self.interp_file) def run_interpolation(self, max_workers=None, **kwargs): """Run interpolation to get final final. Runs log interpolation up to max_log_height (usually 100m) and linear interpolation above this.""" - LogLinInterpolator.run( - infile=self.combined_file, - outfile=self.interp_file, - max_workers=max_workers, - variables=self.variables, - overwrite=self.overwrite, - **kwargs, - ) + LogLinInterpolator.run(infile=self.combined_file, + outfile=self.interp_file, + max_workers=max_workers, + variables=self.variables, + overwrite=self.overwrite, + **kwargs) def get_monthly_file(self, interp_workers=None, **interp_kwargs): """Download level and surface files, process variables, and combine @@ -649,10 +613,8 @@ def all_months_exist(cls, year, file_pattern): """ return all( os.path.exists( - file_pattern.format(year=year, month=str(month).zfill(2)) - ) - for month in range(1, 13) - ) + file_pattern.format(year=year, month=str(month).zfill(2))) + for month in range(1, 13)) @classmethod def already_pruned(cls, infile): @@ -674,16 +636,14 @@ def prune_output(cls, infile): tmp_file = cls.get_tmp_file(infile) with Dataset(infile, 'r') as old_ds: with Dataset(tmp_file, 'w') as new_ds: - new_ds = cls.init_dims( - old_ds, new_ds, ('time', 'latitude', 'longitude') - ) + new_ds = cls.init_dims(old_ds, new_ds, + ('time', 'latitude', 'longitude')) for var in old_ds.variables: if any(name in var for name in cls.KEEP_VARIABLES): old_var = old_ds[var] vals = old_var[:] _ = new_ds.createVariable( - var, old_var.dtype, dimensions=old_var.dimensions - ) + var, old_var.dtype, dimensions=old_var.dimensions) new_ds[var][:] = vals if hasattr(old_var, 'units'): new_ds[var].units = old_var.units @@ -693,27 +653,23 @@ def prune_output(cls, infile): if hasattr(old_var, 'long_name'): new_ds[var].long_name = old_var.long_name os.system(f'mv {tmp_file} {infile}') - logger.info( - f'Finished pruning variables in {infile}. Moved ' - f'{tmp_file} to {infile}.' - ) + logger.info(f'Finished pruning variables in {infile}. 
Moved ' + f'{tmp_file} to {infile}.') @classmethod - def run_month( - cls, - year, - month, - area, - levels, - combined_out_pattern, - interp_out_pattern=None, - run_interp=True, - overwrite=False, - required_shape=None, - interp_workers=None, - variables=None, - **interp_kwargs, - ): + def run_month(cls, + year, + month, + area, + levels, + combined_out_pattern, + interp_out_pattern=None, + run_interp=True, + overwrite=False, + required_shape=None, + interp_workers=None, + variables=None, + **interp_kwargs): """Run routine for all months in the requested year. Parameters @@ -749,40 +705,35 @@ def run_month( **interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ - downloader = cls( - year=year, - month=month, - area=area, - levels=levels, - combined_out_pattern=combined_out_pattern, - interp_out_pattern=interp_out_pattern, - run_interp=run_interp, - overwrite=overwrite, - required_shape=required_shape, - variables=variables, - ) - downloader.get_monthly_file( - interp_workers=interp_workers, **interp_kwargs - ) + downloader = cls(year=year, + month=month, + area=area, + levels=levels, + combined_out_pattern=combined_out_pattern, + interp_out_pattern=interp_out_pattern, + run_interp=run_interp, + overwrite=overwrite, + required_shape=required_shape, + variables=variables) + downloader.get_monthly_file(interp_workers=interp_workers, + **interp_kwargs) @classmethod - def run_year( - cls, - year, - area, - levels, - combined_out_pattern, - combined_yearly_file, - interp_out_pattern=None, - interp_yearly_file=None, - run_interp=True, - overwrite=False, - required_shape=None, - max_workers=None, - interp_workers=None, - variables=None, - **interp_kwargs, - ): + def run_year(cls, + year, + area, + levels, + combined_out_pattern, + combined_yearly_file, + interp_out_pattern=None, + interp_yearly_file=None, + run_interp=True, + overwrite=False, + required_shape=None, + max_workers=None, + interp_workers=None, + variables=None, + **interp_kwargs): """Run routine for all months in the requested year. Parameters @@ -825,20 +776,18 @@ def run_year( """ if max_workers == 1: for month in range(1, 13): - cls.run_month( - year=year, - month=month, - area=area, - levels=levels, - combined_out_pattern=combined_out_pattern, - interp_out_pattern=interp_out_pattern, - run_interp=run_interp, - overwrite=overwrite, - required_shape=required_shape, - interp_workers=interp_workers, - variables=variables, - **interp_kwargs, - ) + cls.run_month(year=year, + month=month, + area=area, + levels=levels, + combined_out_pattern=combined_out_pattern, + interp_out_pattern=interp_out_pattern, + run_interp=run_interp, + overwrite=overwrite, + required_shape=required_shape, + interp_workers=interp_workers, + variables=variables, + **interp_kwargs) else: futures = {} with ThreadPoolExecutor(max_workers=max_workers) as exe: @@ -856,20 +805,15 @@ def run_year( required_shape=required_shape, interp_workers=interp_workers, variables=variables, - **interp_kwargs, - ) + **interp_kwargs) futures[future] = {'year': year, 'month': month} - logger.info( - f'Submitted future for year {year} and month ' - f'{month}.' - ) + logger.info(f'Submitted future for year {year} and month ' + f'{month}.') for future in as_completed(futures): future.result() v = futures[future] - logger.info( - f'Finished future for year {v["year"]} and month ' - f'{v["month"]}.' 
- ) + logger.info(f'Finished future for year {v["year"]} and month ' + f'{v["month"]}.') cls.make_yearly_file(year, combined_out_pattern, combined_yearly_file) @@ -890,10 +834,8 @@ def make_yearly_file(cls, year, file_pattern, yearly_file): yearly_file : str Name of yearly file made from monthly files. """ - msg = ( - f'Not all monthly files with file_patten {file_pattern} for ' - f'year {year} exist.' - ) + msg = (f'Not all monthly files with file_patten {file_pattern} for ' + f'year {year} exist.') assert cls.all_months_exist(year, file_pattern), msg files = [ @@ -911,16 +853,14 @@ def make_yearly_file(cls, year, file_pattern, yearly_file): logger.info(f'{yearly_file} already exists.') @classmethod - def _check_single_file( - cls, - res, - var_list=None, - check_nans=True, - check_heights=True, - max_interp_height=200, - required_shape=None, - max_workers=10, - ): + def _check_single_file(cls, + res, + var_list=None, + check_nans=True, + check_heights=True, + max_interp_height=200, + required_shape=None, + max_workers=10): """Make sure given files include the given variables. Check for NaNs and required shape. @@ -961,21 +901,15 @@ def _check_single_file( *res['latitude'].shape, *res['longitude'].shape, ) - good_shape = ( - 'NA' if required_shape is None else (res_shape == required_shape) - ) - good_hgts = ( - 'NA' - if not check_heights - else cls.check_heights( - res, - max_interp_height=max_interp_height, - max_workers=max_workers, - ) - ) - nan_pct = ( - 'NA' if not check_nans else cls.get_nan_pct(res, var_list=var_list) - ) + good_shape = ('NA' if required_shape is None else + (res_shape == required_shape)) + good_hgts = ('NA' if not check_heights else cls.check_heights( + res, + max_interp_height=max_interp_height, + max_workers=max_workers, + )) + nan_pct = ('NA' if not check_nans else cls.get_nan_pct( + res, var_list=var_list)) if not good_vars: mask = np.array([var not in res for var in var_list]) @@ -1007,23 +941,20 @@ def check_heights(cls, res, max_interp_height=200, max_workers=10): location and timestep """ gp = res['zg'].values - sfc_hgt = np.repeat( - res['orog'].values[:, np.newaxis, ...], gp.shape[1], axis=1 - ) + sfc_hgt = np.repeat(res['orog'].values[:, np.newaxis, ...], + gp.shape[1], + axis=1) heights = gp - sfc_hgt heights = heights.reshape(heights.shape[0], heights.shape[1], -1) checks = [] logger.info( - f'Checking heights with max_interp_height={max_interp_height}.' - ) + f'Checking heights with max_interp_height={max_interp_height}.') if max_workers == 1: for idt in range(heights.shape[0]): checks.append( cls._check_heights_single_ts( - heights[idt], max_interp_height=max_interp_height - ) - ) + heights[idt], max_interp_height=max_interp_height)) msg = f'Finished check for {idt + 1} of {heights.shape[0]}.' 
logger.debug(msg) else: @@ -1036,17 +967,13 @@ def check_heights(cls, res, max_interp_height=200, max_workers=10): max_interp_height=max_interp_height, ) futures.append(future) - msg = ( - f'Submitted height check for {idt + 1} of ' - f'{heights.shape[0]}' - ) + msg = (f'Submitted height check for {idt + 1} of ' + f'{heights.shape[0]}') logger.info(msg) for i, future in enumerate(as_completed(futures)): checks.append(future.result()) - msg = ( - f'Finished height check for {i + 1} of ' - f'{heights.shape[0]}' - ) + msg = (f'Finished height check for {i + 1} of ' + f'{heights.shape[0]}') logger.info(msg) return all(checks) @@ -1103,16 +1030,14 @@ def get_nan_pct(cls, res, var_list=None): return 100 * nan_count / elem_count @classmethod - def check_single_file( - cls, - file, - var_list=None, - check_nans=True, - check_heights=True, - max_interp_height=200, - required_shape=None, - max_workers=10, - ): + def check_single_file(cls, + file, + var_list=None, + check_nans=True, + check_heights=True, + max_interp_height=200, + required_shape=None, + max_workers=10): """Make sure given files include the given variables. Check for NaNs and required shape. @@ -1152,9 +1077,7 @@ def check_single_file( good_shape = None good_vars = None good_hgts = None - var_list = ( - var_list if var_list is not None else cls.DEFAULT_RENAMED_VARS - ) + var_list = (var_list if var_list is not None else cls.VALID_VARIABLES) try: res = xr.open_dataset(file) except Exception as e: @@ -1164,30 +1087,26 @@ def check_single_file( good = False if good: - out = cls._check_single_file( - res, - var_list, - check_nans=check_nans, - check_heights=check_heights, - max_interp_height=max_interp_height, - required_shape=required_shape, - max_workers=max_workers, - ) + out = cls._check_single_file(res, + var_list, + check_nans=check_nans, + check_heights=check_heights, + max_interp_height=max_interp_height, + required_shape=required_shape, + max_workers=max_workers) good_vars, good_shape, good_hgts, nan_pct = out return good_vars, good_shape, good_hgts, nan_pct @classmethod - def run_files_checks( - cls, - file_pattern, - var_list=None, - required_shape=None, - check_nans=True, - check_heights=True, - max_interp_height=200, - max_workers=None, - height_check_workers=10, - ): + def run_files_checks(cls, + file_pattern, + var_list=None, + required_shape=None, + check_nans=True, + check_heights=True, + max_interp_height=200, + max_workers=None, + height_check_workers=10): """Make sure given files include the given variables. Check for NaNs and required shape. 
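Note on the `check_heights` hunks above: the heights being validated are heights above ground, obtained by broadcasting the surface elevation (`orog`) across the level axis and subtracting it from geopotential height (`zg`). A minimal sketch of that derivation is below; the per-timestep test is a hypothetical stand-in, since the body of `_check_heights_single_ts` is not part of this diff.

```python
import numpy as np

def heights_above_ground(zg, orog):
    """Height above ground from geopotential height zg (time, level, lat, lon)
    and surface elevation orog (time, lat, lon), using the same np.repeat
    broadcast as EraDownloader.check_heights."""
    sfc_hgt = np.repeat(orog[:, np.newaxis, ...], zg.shape[1], axis=1)
    return zg - sfc_hgt

def check_heights_single_ts(heights_t, max_interp_height=200):
    """Hypothetical stand-in for _check_heights_single_ts: require every grid
    column to have at least one level at or below max_interp_height so the
    interpolation range is covered down to the surface layer."""
    cols = heights_t.reshape(heights_t.shape[0], -1)  # (level, lat*lon)
    return bool((cols.min(axis=0) <= max_interp_height).all())
```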
@@ -1224,9 +1143,9 @@ def run_files_checks( files = glob(file_pattern) else: files = file_pattern - df = pd.DataFrame( - columns=['file', 'good_vars', 'good_shape', 'good_hgts', 'nan_pct'] - ) + df = pd.DataFrame(columns=[ + 'file', 'good_vars', 'good_shape', 'good_hgts', 'nan_pct' + ]) df['file'] = [os.path.basename(file) for file in files] if max_workers == 1: for i, file in enumerate(files): @@ -1238,36 +1157,29 @@ def run_files_checks( check_heights=check_heights, max_interp_height=max_interp_height, max_workers=height_check_workers, - required_shape=required_shape, - ) - df.at[i, df.columns[1:]] = out + required_shape=required_shape) + df.loc[i, df.columns[1:]] = out logger.info(f'Finished checking {file}.') else: futures = {} with ThreadPoolExecutor(max_workers=max_workers) as exe: for i, file in enumerate(files): - future = exe.submit( - cls.check_single_file, - file=file, - var_list=var_list, - check_nans=check_nans, - check_heights=check_heights, - max_interp_height=max_interp_height, - max_workers=height_check_workers, - required_shape=required_shape, - ) - msg = ( - f'Submitted file check future for {file}. Future ' - f'{i + 1} of {len(files)}.' - ) + future = exe.submit(cls.check_single_file, + file=file, + var_list=var_list, + check_nans=check_nans, + check_heights=check_heights, + max_interp_height=max_interp_height, + max_workers=height_check_workers, + required_shape=required_shape) + msg = (f'Submitted file check future for {file}. Future ' + f'{i + 1} of {len(files)}.') logger.info(msg) futures[future] = i for i, future in enumerate(as_completed(futures)): out = future.result() - df.at[futures[future], df.columns[1:]] = out - msg = ( - f'Finished checking {df["file"].iloc[futures[future]]}.' - f' Future {i + 1} of {len(files)}.' - ) + df.loc[futures[future], df.columns[1:]] = out + msg = (f'Finished checking {df["file"].iloc[futures[future]]}.' + f' Future {i + 1} of {len(files)}.') logger.info(msg) return df diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 41306bfc7..7a7f51910 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -8,6 +8,7 @@ as_completed, ) from glob import glob +from typing import ClassVar from warnings import warn import numpy as np @@ -22,7 +23,6 @@ init_logger(__name__, log_level='DEBUG') init_logger('sup3r', log_level='DEBUG') - logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ class LogLinInterpolator: max_log_height, linearly interpolate components above max_log_height meters, and save to file""" - DEFAULT_OUTPUT_HEIGHTS = { + DEFAULT_OUTPUT_HEIGHTS: ClassVar[dict] = { 'u': [40, 80, 120, 160, 200], 'v': [40, 80, 120, 160, 200], 'temperature': [10, 40, 80, 100, 120, 160, 200], @@ -69,10 +69,8 @@ def __init__( self.infile = infile self.outfile = outfile - msg = ( - 'output_heights must be a dictionary with variables as keys ' - f'and lists of heights as values. Received: {output_heights}.' - ) + msg = ('output_heights must be a dictionary with variables as keys ' + f'and lists of heights as values. Received: {output_heights}.') assert output_heights is None or isinstance(output_heights, dict), msg self.new_heights = output_heights or self.DEFAULT_OUTPUT_HEIGHTS @@ -84,11 +82,9 @@ def __init__( msg = f'{self.infile} does not exist. Skipping.' 
assert os.path.exists(self.infile), msg - msg = ( - f'Initializing {self.__class__.__name__} with infile={infile}, ' - f'outfile={outfile}, new_heights={self.new_heights}, ' - f'variables={variables}.' - ) + msg = (f'Initializing {self.__class__.__name__} with infile={infile}, ' + f'outfile={outfile}, new_heights={self.new_heights}, ' + f'variables={variables}.') logger.info(msg) def _load_single_var(self, variable): @@ -111,9 +107,9 @@ def _load_single_var(self, variable): logger.info(f'Loading {self.infile} for {variable}.') with xr.open_dataset(self.infile) as res: gp = res['zg'].values - sfc_hgt = np.repeat( - res['orog'].values[:, np.newaxis, ...], gp.shape[1], axis=1 - ) + sfc_hgt = np.repeat(res['orog'].values[:, np.newaxis, ...], + gp.shape[1], + axis=1) heights = gp - sfc_hgt input_heights = [] @@ -126,9 +122,9 @@ def _load_single_var(self, variable): height_arr = [] shape = (heights.shape[0], 1, *heights.shape[2:]) for height in input_heights: - var_arr.append( - res[f'{variable}_{height}m'].values[:, np.newaxis, ...] - ) + var_arr.append(res[f'{variable}_{height}m'].values[:, + np.newaxis, + ...]) height_arr.append(np.full(shape, height, dtype=np.float32)) if variable in res: @@ -163,8 +159,7 @@ def interpolate_vars(self, max_workers=None): if var not in ('u', 'v'): max_log_height = -np.inf logger.info( - f'Interpolating {var} to heights = {self.new_heights[var]}.' - ) + f'Interpolating {var} to heights = {self.new_heights[var]}.') self.new_data[var] = self.interp_var_to_height( var_array=arrs['data'], @@ -229,7 +224,7 @@ def init_dims(cls, old_ds, new_ds, dims): """ for var in dims: new_ds.createDimension(var, len(old_ds[var])) - _ = new_ds.createVariable(var, old_ds[var].dtype, dimensions=(var)) + _ = new_ds.createVariable(var, old_ds[var].dtype, dimensions=var) new_ds[var][:] = old_ds[var][:] new_ds[var].units = old_ds[var].units return new_ds @@ -238,54 +233,9 @@ def init_dims(cls, old_ds, new_ds, dims): def get_tmp_file(cls, file): """Get temp file for given file. Then only needed variables will be written to the given file.""" - tmp_file = file.replace(".nc", "_tmp.nc") + tmp_file = file.replace('.nc', '_tmp.nc') return tmp_file - @classmethod - def check_prune(cls, infile): - """Check if file has been pruned already.""" - - keep_vars = ('u_', 'v_', 'pressure_', 'temperature_', 'orog') - pruned = True - with Dataset(infile, 'r') as ds: - for var in ds.variables: - if not any(name in var for name in keep_vars): - pruned = False - return pruned - - @classmethod - def prune_output(cls, infile): - """Prune output file to keep just single level variables""" - - logger.info(f'Pruning {infile}.') - tmp_file = cls.get_tmp_file(infile) - keep_vars = ('u_', 'v_', 'pressure_', 'temperature_', 'orog') - with Dataset(infile, 'r') as old_ds: - with Dataset(tmp_file, 'w') as new_ds: - new_ds = cls.init_dims( - old_ds, new_ds, ('time', 'latitude', 'longitude') - ) - for var in old_ds.variables: - if any(name in var for name in keep_vars): - old_var = old_ds[var] - vals = old_var[:] - _ = new_ds.createVariable( - var, np.float32, dimensions=old_var.dimensions - ) - new_ds[var][:] = vals - if hasattr(old_var, 'units'): - new_ds[var].units = old_var.units - if hasattr(old_var, 'standard_name'): - standard_name = old_var.standard_name - new_ds[var].standard_name = standard_name - if hasattr(old_var, 'long_name'): - new_ds[var].long_name = old_var.long_name - os.system(f'mv {tmp_file} {infile}') - logger.info( - f'Finished pruning variables in {infile}. Moved ' - f'{tmp_file} to {infile}.' 
- ) - @classmethod def run( cls, @@ -328,8 +278,7 @@ def run( ) if os.path.exists(outfile) and not overwrite: logger.info( - f'{outfile} already exists and overwrite=False. ' 'Skipping.' - ) + f'{outfile} already exists and overwrite=False. Skipping.') else: log_interp.load() log_interp.interpolate_vars(max_workers=max_workers) @@ -374,8 +323,7 @@ def run_multiple( if max_workers == 1: for _, file in enumerate(infiles): outfile = os.path.basename(file).replace( - '.nc', '_all_interp.nc' - ) + '.nc', '_all_interp.nc') outfile = os.path.join(out_dir, outfile) cls.run( file, @@ -390,36 +338,29 @@ def run_multiple( with ThreadPoolExecutor(max_workers=max_workers) as exe: for i, file in enumerate(infiles): outfile = os.path.basename(file).replace( - '.nc', '_all_interp.nc' - ) + '.nc', '_all_interp.nc') outfile = os.path.join(out_dir, outfile) futures.append( - exe.submit( - cls.run, - file, - outfile, - output_heights=output_heights, - variables=variables, - max_log_height=max_log_height, - overwrite=overwrite, - ) - ) + exe.submit(cls.run, + file, + outfile, + output_heights=output_heights, + variables=variables, + max_log_height=max_log_height, + overwrite=overwrite)) logger.info( - f'{i + 1} of {len(infiles)} futures submitted.' - ) + f'{i + 1} of {len(infiles)} futures submitted.') for i, future in enumerate(as_completed(futures)): future.result() logger.info(f'{i + 1} of {len(futures)} futures complete.') @classmethod - def pbl_interp_to_height( - cls, - lev_array, - var_array, - levels, - fixed_level_mask=None, - max_log_height=100, - ): + def pbl_interp_to_height(cls, + lev_array, + var_array, + levels, + fixed_level_mask=None, + max_log_height=100): """Fit ws log law to data below max_log_height. Parameters @@ -464,19 +405,14 @@ def ws_log_profile(z, a, b): var_mask = (0 < lev_array_samp) & (lev_array_samp <= max_log_height) try: - popt, _ = curve_fit( - ws_log_profile, - lev_array_samp[var_mask], - var_array_samp[var_mask], - ) + popt, _ = curve_fit(ws_log_profile, lev_array_samp[var_mask], + var_array_samp[var_mask]) log_ws = ws_log_profile(levels[lev_mask], *popt) except Exception as e: - msg = ( - 'Log interp failed with (h, ws) = ' - f'({lev_array_samp[var_mask]}, ' - f'{var_array_samp[var_mask]}). {e} ' - 'Using linear interpolation.' - ) + msg = ('Log interp failed with (h, ws) = ' + f'({lev_array_samp[var_mask]}, ' + f'{var_array_samp[var_mask]}). {e} ' + 'Using linear interpolation.') good = False logger.warning(msg) warn(msg) @@ -488,14 +424,12 @@ def ws_log_profile(z, a, b): return log_ws, good @classmethod - def _interp_var_to_height( - cls, - lev_array, - var_array, - levels, - fixed_level_mask=None, - max_log_height=100, - ): + def _interp_var_to_height(cls, + lev_array, + var_array, + levels, + fixed_level_mask=None, + max_log_height=100): """Fit ws log law to wind data below max_log_height and linearly interpolate data above. Linearly interpolate non wind data. 
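For the `pbl_interp_to_height` hunk above: below `max_log_height` the wind data are fit to a log profile with `scipy.optimize.curve_fit` and the fit is evaluated at the requested output levels, with linear interpolation as the fallback when the fit fails. A minimal sketch, assuming the common two-parameter form ws(z) = a*ln(z) + b; the actual body of `ws_log_profile` is not shown in this diff.

```python
import numpy as np
from scipy.optimize import curve_fit

def fit_log_profile(heights, windspeed, out_levels, max_log_height=100):
    """Fit ws(z) = a * ln(z) + b to samples below max_log_height and evaluate
    the fit at the requested output levels (assumed functional form)."""
    def ws_log_profile(z, a, b):
        return a * np.log(z) + b

    samp_mask = (heights > 0) & (heights <= max_log_height)
    lev_mask = out_levels <= max_log_height
    popt, _ = curve_fit(ws_log_profile,
                        heights[samp_mask],
                        windspeed[samp_mask])
    return ws_log_profile(out_levels[lev_mask], *popt)

# Illustrative single-column samples (made-up values)
heights = np.array([10.0, 30.0, 60.0, 90.0, 150.0])
windspeed = np.array([4.0, 5.1, 5.9, 6.3, 7.0])
print(fit_log_profile(heights, windspeed, np.array([40.0, 80.0])))
```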
@@ -532,43 +466,35 @@ def _interp_var_to_height( good = True hgt_check = any(levels < max_log_height) and any( - lev_array < max_log_height - ) + lev_array < max_log_height) if hgt_check: log_ws, good = cls.pbl_interp_to_height( lev_array, var_array, levels, fixed_level_mask=fixed_level_mask, - max_log_height=max_log_height, - ) + max_log_height=max_log_height) if any(levels > max_log_height): lev_mask = levels >= max_log_height var_mask = lev_array >= max_log_height if len(lev_array[var_mask]) > 1: - lin_ws = interp1d( - lev_array[var_mask], - var_array[var_mask], - fill_value='extrapolate', - )(levels[lev_mask]) + lin_ws = interp1d(lev_array[var_mask], + var_array[var_mask], + fill_value='extrapolate')(levels[lev_mask]) elif len(lev_array) > 1: - msg = ( - 'Requested interpolation levels are outside the ' - f'available range: lev_array={lev_array}, ' - f'levels={levels}. Using linear extrapolation.' - ) - lin_ws = interp1d( - lev_array, var_array, fill_value='extrapolate' - )(levels[lev_mask]) + msg = ('Requested interpolation levels are outside the ' + f'available range: lev_array={lev_array}, ' + f'levels={levels}. Using linear extrapolation.') + lin_ws = interp1d(lev_array, + var_array, + fill_value='extrapolate')(levels[lev_mask]) good = False logger.warning(msg) warn(msg) else: - msg = ( - 'Data seems to be all NaNs. Something may have gone ' - 'wrong during download.' - ) + msg = ('Data seems to be all NaNs. Something may have gone ' + 'wrong during download.') raise OSError(msg) if log_ws is not None and lin_ws is not None: @@ -581,10 +507,8 @@ def _interp_var_to_height( out = lin_ws if log_ws is None and lin_ws is None: - msg = ( - f'No interpolation was performed for lev_array={lev_array} ' - f'and levels={levels}' - ) + msg = (f'No interpolation was performed for lev_array={lev_array} ' + f'and levels={levels}') raise RuntimeError(msg) return out, good @@ -623,15 +547,13 @@ def _get_timestep_interp_input(cls, lev_array, var_array, idt): return h_t, var_t, mask @classmethod - def interp_single_ts( - cls, - hgt_t, - var_t, - mask, - levels, - fixed_level_mask=None, - max_log_height=100, - ): + def interp_single_ts(cls, + hgt_t, + var_t, + mask, + levels, + fixed_level_mask=None, + max_log_height=100): """Perform interpolation for a single timestep specified by the index idt @@ -677,15 +599,13 @@ def interp_single_ts( return np.array(out_array), np.array(checks) @classmethod - def interp_var_to_height( - cls, - var_array, - lev_array, - levels, - fixed_level_mask=None, - max_log_height=100, - max_workers=None, - ): + def interp_var_to_height(cls, + var_array, + lev_array, + levels, + fixed_level_mask=None, + max_log_height=100, + max_workers=None): """Interpolate data array to given level(s) based on h_array. Interpolation is done using windspeed log profile and is done for every 'z' column of [var, h] data. @@ -720,8 +640,7 @@ def interp_var_to_height( Array of interpolated values. """ lev_array, levels = Interpolator.prep_level_interp( - var_array, lev_array, levels - ) + var_array, lev_array, levels) array_shape = var_array.shape @@ -735,8 +654,7 @@ def interp_var_to_height( if max_workers == 1: for idt in range(array_shape[0]): h_t, v_t, mask = cls._get_timestep_interp_input( - lev_array, var_array, idt - ) + lev_array, var_array, idt) out, checks = cls.interp_single_ts( h_t, v_t, @@ -749,15 +667,13 @@ def interp_var_to_height( total_checks.append(checks) logger.info( - f'{idt + 1} of {array_shape[0]} timesteps finished.' 
- ) + f'{idt + 1} of {array_shape[0]} timesteps finished.') else: with ProcessPoolExecutor(max_workers=max_workers) as exe: for idt in range(array_shape[0]): h_t, v_t, mask = cls._get_timestep_interp_input( - lev_array, var_array, idt - ) + lev_array, var_array, idt) future = exe.submit( cls.interp_single_ts, h_t, @@ -769,8 +685,7 @@ def interp_var_to_height( ) futures[future] = idt logger.info( - f'{idt + 1} of {array_shape[0]} futures submitted.' - ) + f'{idt + 1} of {array_shape[0]} futures submitted.') for i, future in enumerate(as_completed(futures)): out, checks = future.result() out_array[:, futures[future], :] = out @@ -780,22 +695,16 @@ def interp_var_to_height( total_checks = np.concatenate(total_checks) good_count = total_checks.sum() total_count = len(total_checks) - logger.info( - 'Percent of points interpolated without issue: ' - f'{100 * good_count / total_count:.2f}' - ) + logger.info('Percent of points interpolated without issue: ' + f'{100 * good_count / total_count:.2f}') # Reshape out_array if isinstance(levels, (float, np.float32, int)): shape = (1, array_shape[-4], array_shape[-2], array_shape[-1]) out_array = out_array.T.reshape(shape) else: - shape = ( - len(levels), - array_shape[-4], - array_shape[-2], - array_shape[-1], - ) + shape = (len(levels), array_shape[-4], array_shape[-2], + array_shape[-1]) out_array = out_array.T.reshape(shape) return out_array
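The era_downloader and interpolate_log_profile pieces above compose into a download-then-interpolate workflow: `EraDownloader.run_month` fetches and combines the surface and pressure-level files for one month and, when `run_interp` is enabled, hands the combined file to `LogLinInterpolator.run`. A hedged usage sketch using only keyword arguments visible in this diff; the bounding box, levels, file patterns, heights, and variable list are illustrative values, not project defaults.

```python
from sup3r.utilities.era_downloader import EraDownloader
from sup3r.utilities.interpolate_log_profile import LogLinInterpolator

# Hypothetical bounding box and pressure levels (hPa); not project defaults.
area = [40, -106, 39, -105]
levels = [1000, 975, 950, 925, 900]

EraDownloader.run_month(
    year=2017,
    month=1,
    area=area,
    levels=levels,
    combined_out_pattern='./era5_{year}_{month}_combined.nc',
    interp_out_pattern='./era5_{year}_{month}_interp.nc',
    run_interp=True,
    overwrite=False,
    variables=['u', 'v'],
)

# The interpolator can also be run directly on an existing combined file.
LogLinInterpolator.run(
    infile='./era5_2017_01_combined.nc',
    outfile='./era5_2017_01_interp.nc',
    output_heights={'u': [40, 80, 120], 'v': [40, 80, 120]},
    variables=['u', 'v'],
    max_log_height=100,
    overwrite=False,
    max_workers=1,
)
```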