-
Notifications
You must be signed in to change notification settings - Fork 688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
extract_input_target_forcings add option for left-justification of train/eval #56
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,6 @@ | |
|
||
from typing import Any, Mapping, Sequence, Tuple, Union | ||
|
||
from graphcast import solar_radiation | ||
import numpy as np | ||
import pandas as pd | ||
import xarray | ||
|
@@ -37,15 +36,6 @@ | |
|
||
DAY_PROGRESS = "day_progress" | ||
YEAR_PROGRESS = "year_progress" | ||
_DERIVED_VARS = { | ||
DAY_PROGRESS, | ||
f"{DAY_PROGRESS}_sin", | ||
f"{DAY_PROGRESS}_cos", | ||
YEAR_PROGRESS, | ||
f"{YEAR_PROGRESS}_sin", | ||
f"{YEAR_PROGRESS}_cos", | ||
} | ||
TISR = "toa_incident_solar_radiation" | ||
|
||
|
||
def get_year_progress(seconds_since_epoch: np.ndarray) -> np.ndarray: | ||
|
@@ -133,7 +123,10 @@ def featurize_progress( | |
|
||
|
||
def add_derived_vars(data: xarray.Dataset) -> None: | ||
"""Adds year and day progress features to `data` in place if missing. | ||
"""Adds year and day progress features to `data` in place. | ||
|
||
NOTE: `toa_incident_solar_radiation` needs to be computed in this function | ||
as well. | ||
|
||
Args: | ||
data: Xarray dataset to which derived features will be added. | ||
|
@@ -154,71 +147,38 @@ def add_derived_vars(data: xarray.Dataset) -> None: | |
) | ||
batch_dim = ("batch",) if "batch" in data.dims else () | ||
|
||
# Add year progress features if missing. | ||
if YEAR_PROGRESS not in data.data_vars: | ||
year_progress = get_year_progress(seconds_since_epoch) | ||
data.update( | ||
featurize_progress( | ||
name=YEAR_PROGRESS, | ||
dims=batch_dim + ("time",), | ||
progress=year_progress, | ||
) | ||
) | ||
|
||
# Add day progress features if missing. | ||
if DAY_PROGRESS not in data.data_vars: | ||
longitude_coord = data.coords["lon"] | ||
day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data) | ||
data.update( | ||
featurize_progress( | ||
name=DAY_PROGRESS, | ||
dims=batch_dim + ("time",) + longitude_coord.dims, | ||
progress=day_progress, | ||
) | ||
) | ||
|
||
|
||
def add_tisr_var(data: xarray.Dataset) -> None: | ||
"""Adds TISR feature to `data` in place if missing. | ||
|
||
Args: | ||
data: Xarray dataset to which TISR feature will be added. | ||
|
||
Raises: | ||
ValueError if `datetime`, 'lat', or `lon` are not in `data` coordinates. | ||
""" | ||
|
||
if TISR in data.data_vars: | ||
return | ||
|
||
for coord in ("datetime", "lat", "lon"): | ||
if coord not in data.coords: | ||
raise ValueError(f"'{coord}' must be in `data` coordinates.") | ||
|
||
# Remove `batch` dimension of size one if present. An error will be raised if | ||
# the `batch` dimension exists and has size greater than one. | ||
data_no_batch = data.squeeze("batch") if "batch" in data.dims else data | ||
|
||
tisr = solar_radiation.get_toa_incident_solar_radiation_for_xarray( | ||
data_no_batch, use_jit=True | ||
# Add year progress features. | ||
year_progress = get_year_progress(seconds_since_epoch) | ||
data.update( | ||
featurize_progress( | ||
name=YEAR_PROGRESS, dims=batch_dim + ("time",), progress=year_progress | ||
) | ||
) | ||
|
||
if "batch" in data.dims: | ||
tisr = tisr.expand_dims("batch", axis=0) | ||
|
||
data.update({TISR: tisr}) | ||
# Add day progress features. | ||
longitude_coord = data.coords["lon"] | ||
day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data) | ||
data.update( | ||
featurize_progress( | ||
name=DAY_PROGRESS, | ||
dims=batch_dim + ("time",) + longitude_coord.dims, | ||
progress=day_progress, | ||
) | ||
) | ||
|
||
|
||
def extract_input_target_times( | ||
dataset: xarray.Dataset, | ||
input_duration: TimedeltaLike, | ||
target_lead_times: TargetLeadTimes, | ||
justify: str | ||
) -> Tuple[xarray.Dataset, xarray.Dataset]: | ||
"""Extracts inputs and targets for prediction, from a Dataset with a time dim. | ||
|
||
The input period is assumed to be contiguous (specified by a duration), but | ||
the targets can be a list of arbitrary lead times. | ||
|
||
|
||
Examples: | ||
|
||
# Use 18 hours of data as inputs, and two specific lead times as targets: | ||
|
@@ -256,6 +216,16 @@ def extract_input_target_times( | |
(inclusive) lead times, or a sequence of lead times. Lead times should be | ||
Timedeltas (or something convertible to). They are given relative to the | ||
final input timestep, and should be positive. | ||
justify: Defines whether inputs and targets are extracted from the beginning | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any chance you can add a test that this does what you'd expect? |
||
or end of the example batch. | ||
When using 'left' justify (default), the final input is defined as the 2nd time element | ||
of the example batch. Targets follow immediately thereafter as defined by the | ||
leadtime. | ||
Alternatively, 'right' justify is where the targets start with the last time | ||
element and the inputs are the two time elements preceding the first target. | ||
Note: It is important to realize that the first prediction can be no | ||
earlier than 12Z based on current construction of example batches. 00Z and | ||
06Z are inputs. | ||
|
||
Returns: | ||
inputs: | ||
|
@@ -270,23 +240,43 @@ def extract_input_target_times( | |
(target_lead_times, target_duration | ||
) = _process_target_lead_times_and_get_duration(target_lead_times) | ||
|
||
# Shift the coordinates for the time axis so that a timedelta of zero | ||
# corresponds to the forecast reference time. That is, the final timestep | ||
# that's available as input to the forecast, with all following timesteps | ||
# forming the target period which needs to be predicted. | ||
# This means the time coordinates are now forecast lead times. | ||
input_duration = pd.Timedelta(input_duration) | ||
time = dataset.coords["time"] | ||
dataset = dataset.assign_coords(time=time + target_duration - time[-1]) | ||
|
||
# Slice out targets: | ||
targets = dataset.sel({"time": target_lead_times}) | ||
# Slice out inputs and targets: | ||
if justify == 'left': | ||
# Inputs correspond to the first time elements within the input duration | ||
# Targets follow immediatly after per the target lead times | ||
target_start_time = int(input_duration.total_seconds()/3600/6) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This shouldn't be 6h specific. |
||
target_end_time = int(input_duration.total_seconds()/3600/6) + int(target_duration.total_seconds()/3600/6) | ||
|
||
inputs = dataset.isel(time=slice(int(target_start_time))) | ||
inputs['time'] = inputs['time'] - input_duration + time[1] | ||
|
||
targets = dataset.isel(time=slice(target_start_time,target_end_time)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please take a look at the google python style guide. Some general comments:
|
||
targets['time'] = targets['time'] - input_duration + time[1] | ||
|
||
# targets = targets.assign_coords(time=time[1:target_end_time+1]) | ||
|
||
elif justify == 'right': | ||
# Shift the coordinates for the time axis so that a timedelta of zero | ||
# corresponds to the forecast reference time. That is, the final timestep | ||
# that's available as input to the forecast, with all following timesteps | ||
# forming the target period which needs to be predicted. | ||
# This means the time coordinates are now forecast lead times. | ||
dataset = dataset.assign_coords(time=time + target_duration - time[-1]) | ||
|
||
targets = dataset.sel({"time": target_lead_times}) | ||
# Both endpoints are inclusive with label-based slicing, so we offset by a | ||
# small epsilon to make one of the endpoints non-inclusive: | ||
zero = pd.Timedelta(0) | ||
epsilon = pd.Timedelta(1, "ns") | ||
inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)}) | ||
else: | ||
raise ValueError( | ||
"justify must either be 'left' or 'right'" | ||
) | ||
|
||
input_duration = pd.Timedelta(input_duration) | ||
# Both endpoints are inclusive with label-based slicing, so we offset by a | ||
# small epsilon to make one of the endpoints non-inclusive: | ||
zero = pd.Timedelta(0) | ||
epsilon = pd.Timedelta(1, "ns") | ||
inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)}) | ||
return inputs, targets | ||
|
||
|
||
|
@@ -311,8 +301,9 @@ def _process_target_lead_times_and_get_duration( | |
|
||
# A list of multiple (not necessarily contiguous) lead times: | ||
target_lead_times = [pd.Timedelta(x) for x in target_lead_times] | ||
target_lead_times.sort() | ||
target_lead_times.sort() | ||
target_duration = target_lead_times[-1] | ||
print(target_lead_times,target_duration) | ||
return target_lead_times, target_duration | ||
|
||
|
||
|
@@ -325,25 +316,24 @@ def extract_inputs_targets_forcings( | |
pressure_levels: Tuple[int, ...], | ||
input_duration: TimedeltaLike, | ||
target_lead_times: TargetLeadTimes, | ||
justify: str = 'left' | ||
) -> Tuple[xarray.Dataset, xarray.Dataset, xarray.Dataset]: | ||
"""Extracts inputs, targets and forcings according to requirements.""" | ||
dataset = dataset.sel(level=list(pressure_levels)) | ||
|
||
# "Forcings" include derived variables that do not exist in the original ERA5 | ||
# or HRES datasets, as well as other variables (e.g. tisr) that need to be | ||
# computed manually for the target lead times. Compute the requested ones. | ||
if set(forcing_variables) & _DERIVED_VARS: | ||
# "Forcings" are derived variables and do not exist in the original ERA5 or | ||
# HRES datasets. Compute them if they are not in `dataset`. | ||
if not set(forcing_variables).issubset(set(dataset.data_vars)): | ||
add_derived_vars(dataset) | ||
if set(forcing_variables) & {TISR}: | ||
add_tisr_var(dataset) | ||
|
||
# `datetime` is needed by add_derived_vars but breaks autoregressive rollouts. | ||
dataset = dataset.drop_vars("datetime") | ||
|
||
inputs, targets = extract_input_target_times( | ||
dataset, | ||
input_duration=input_duration, | ||
target_lead_times=target_lead_times) | ||
target_lead_times=target_lead_times, | ||
justify=justify) | ||
|
||
if set(forcing_variables) & set(target_variables): | ||
raise ValueError( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is probably better as an enum instead of string.