diff --git a/graphcast/data_utils.py b/graphcast/data_utils.py index e5fea84..b683219 100644 --- a/graphcast/data_utils.py +++ b/graphcast/data_utils.py @@ -15,7 +15,6 @@ from typing import Any, Mapping, Sequence, Tuple, Union -from graphcast import solar_radiation import numpy as np import pandas as pd import xarray @@ -37,15 +36,6 @@ DAY_PROGRESS = "day_progress" YEAR_PROGRESS = "year_progress" -_DERIVED_VARS = { - DAY_PROGRESS, - f"{DAY_PROGRESS}_sin", - f"{DAY_PROGRESS}_cos", - YEAR_PROGRESS, - f"{YEAR_PROGRESS}_sin", - f"{YEAR_PROGRESS}_cos", -} -TISR = "toa_incident_solar_radiation" def get_year_progress(seconds_since_epoch: np.ndarray) -> np.ndarray: @@ -133,7 +123,10 @@ def featurize_progress( def add_derived_vars(data: xarray.Dataset) -> None: - """Adds year and day progress features to `data` in place if missing. + """Adds year and day progress features to `data` in place. + + NOTE: `toa_incident_solar_radiation` needs to be computed in this function + as well. Args: data: Xarray dataset to which derived features will be added. @@ -154,71 +147,38 @@ def add_derived_vars(data: xarray.Dataset) -> None: ) batch_dim = ("batch",) if "batch" in data.dims else () - # Add year progress features if missing. - if YEAR_PROGRESS not in data.data_vars: - year_progress = get_year_progress(seconds_since_epoch) - data.update( - featurize_progress( - name=YEAR_PROGRESS, - dims=batch_dim + ("time",), - progress=year_progress, - ) - ) - - # Add day progress features if missing. - if DAY_PROGRESS not in data.data_vars: - longitude_coord = data.coords["lon"] - day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data) - data.update( - featurize_progress( - name=DAY_PROGRESS, - dims=batch_dim + ("time",) + longitude_coord.dims, - progress=day_progress, - ) - ) - - -def add_tisr_var(data: xarray.Dataset) -> None: - """Adds TISR feature to `data` in place if missing. - - Args: - data: Xarray dataset to which TISR feature will be added.
- - Raises: - ValueError if `datetime`, 'lat', or `lon` are not in `data` coordinates. - """ - - if TISR in data.data_vars: - return - - for coord in ("datetime", "lat", "lon"): - if coord not in data.coords: - raise ValueError(f"'{coord}' must be in `data` coordinates.") - - # Remove `batch` dimension of size one if present. An error will be raised if - # the `batch` dimension exists and has size greater than one. - data_no_batch = data.squeeze("batch") if "batch" in data.dims else data - - tisr = solar_radiation.get_toa_incident_solar_radiation_for_xarray( - data_no_batch, use_jit=True + # Add year progress features. + year_progress = get_year_progress(seconds_since_epoch) + data.update( + featurize_progress( + name=YEAR_PROGRESS, dims=batch_dim + ("time",), progress=year_progress + ) ) - if "batch" in data.dims: - tisr = tisr.expand_dims("batch", axis=0) - - data.update({TISR: tisr}) + # Add day progress features. + longitude_coord = data.coords["lon"] + day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data) + data.update( + featurize_progress( + name=DAY_PROGRESS, + dims=batch_dim + ("time",) + longitude_coord.dims, + progress=day_progress, + ) + ) def extract_input_target_times( dataset: xarray.Dataset, input_duration: TimedeltaLike, target_lead_times: TargetLeadTimes, + justify: str = 'left', ) -> Tuple[xarray.Dataset, xarray.Dataset]: """Extracts inputs and targets for prediction, from a Dataset with a time dim. The input period is assumed to be contiguous (specified by a duration), but the targets can be a list of arbitrary lead times. + Examples: # Use 18 hours of data as inputs, and two specific lead times as targets: @@ -256,6 +216,16 @@ def extract_input_target_times( (inclusive) lead times, or a sequence of lead times. Lead times should be Timedeltas (or something convertible to). They are given relative to the final input timestep, and should be positive.
+ justify: Defines whether inputs and targets are extracted from the beginning + or end of the example batch. + When using 'left' justify (default), the final input is defined as the 2nd time element + of the example batch. Targets follow immediately thereafter as defined by the + lead times. + Alternatively, 'right' justify is where the targets end with the last time + element and the inputs are the two time elements preceding the first target. + Note: It is important to realize that the first prediction can be no + earlier than 12Z based on current construction of example batches. 00Z and + 06Z are inputs. Returns: inputs: @@ -270,23 +240,43 @@ def extract_input_target_times( (target_lead_times, target_duration ) = _process_target_lead_times_and_get_duration(target_lead_times) - # Shift the coordinates for the time axis so that a timedelta of zero - # corresponds to the forecast reference time. That is, the final timestep - # that's available as input to the forecast, with all following timesteps - # forming the target period which needs to be predicted. - # This means the time coordinates are now forecast lead times.
+ input_duration = pd.Timedelta(input_duration) time = dataset.coords["time"] - dataset = dataset.assign_coords(time=time + target_duration - time[-1]) - # Slice out targets: - targets = dataset.sel({"time": target_lead_times}) + # Slice out inputs and targets: + if justify == 'left': + # Inputs correspond to the first time elements within the input duration + # Targets follow immediately after per the target lead times + target_start_time = int(input_duration.total_seconds()/3600/6) + target_end_time = int(input_duration.total_seconds()/3600/6) + int(target_duration.total_seconds()/3600/6) + + inputs = dataset.isel(time=slice(int(target_start_time))) + inputs['time'] = inputs['time'] - input_duration + time[1] + + targets = dataset.isel(time=slice(target_start_time,target_end_time)) + targets['time'] = targets['time'] - input_duration + time[1] + + # NOTE: the index arithmetic above assumes a 6-hourly time axis. + + elif justify == 'right': + # Shift the coordinates for the time axis so that a timedelta of zero + # corresponds to the forecast reference time. That is, the final timestep + # that's available as input to the forecast, with all following timesteps + # forming the target period which needs to be predicted. + # This means the time coordinates are now forecast lead times.
+ dataset = dataset.assign_coords(time=time + target_duration - time[-1]) + + targets = dataset.sel({"time": target_lead_times}) + # Both endpoints are inclusive with label-based slicing, so we offset by a + # small epsilon to make one of the endpoints non-inclusive: + zero = pd.Timedelta(0) + epsilon = pd.Timedelta(1, "ns") + inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)}) + else: + raise ValueError( + "justify must either be 'left' or 'right'" + ) - input_duration = pd.Timedelta(input_duration) - # Both endpoints are inclusive with label-based slicing, so we offset by a - # small epsilon to make one of the endpoints non-inclusive: - zero = pd.Timedelta(0) - epsilon = pd.Timedelta(1, "ns") - inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)}) return inputs, targets @@ -325,17 +315,15 @@ def extract_inputs_targets_forcings( pressure_levels: Tuple[int, ...], input_duration: TimedeltaLike, target_lead_times: TargetLeadTimes, + justify: str = 'left' ) -> Tuple[xarray.Dataset, xarray.Dataset, xarray.Dataset]: """Extracts inputs, targets and forcings according to requirements.""" dataset = dataset.sel(level=list(pressure_levels)) - # "Forcings" include derived variables that do not exist in the original ERA5 - # or HRES datasets, as well as other variables (e.g. tisr) that need to be - # computed manually for the target lead times. Compute the requested ones. - if set(forcing_variables) & _DERIVED_VARS: + # "Forcings" are derived variables and do not exist in the original ERA5 or + # HRES datasets. Compute them if they are not in `dataset`.
+ if not set(forcing_variables).issubset(set(dataset.data_vars)): add_derived_vars(dataset) - if set(forcing_variables) & {TISR}: - add_tisr_var(dataset) # `datetime` is needed by add_derived_vars but breaks autoregressive rollouts. dataset = dataset.drop_vars("datetime") @@ -343,7 +331,8 @@ def extract_inputs_targets_forcings( inputs, targets = extract_input_target_times( dataset, input_duration=input_duration, - target_lead_times=target_lead_times) + target_lead_times=target_lead_times, + justify=justify) if set(forcing_variables) & set(target_variables): raise ValueError(