Skip to content

Commit

Permalink
Switch to Era5t data (#29)
Browse files Browse the repository at this point in the history
* Add paths and global variables in config

* Change lines to switch to ERA5T

* Add predict for xgboost model and modify add lags

* Modify unitary tests
  • Loading branch information
miltminz authored Dec 15, 2020
1 parent 8c840df commit 8a66bd1
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 27 deletions.
25 changes: 25 additions & 0 deletions pyro_risks/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
TEST_FR_VIIRS_XLSX_FALLBACK: str = f"{DATA_FALLBACK}/test_data_VIIRS.xlsx"
TEST_FR_VIIRS_JSON_FALLBACK: str = f"{DATA_FALLBACK}/test_data_VIIRS.json"
TEST_FR_ERA5_2019_FALLBACK: str = f"{DATA_FALLBACK}/test_data_ERA5_2019.nc"
TEST_FR_ERA5T_FALLBACK: str = f"{DATA_FALLBACK}/test_era5t_to_merge.nc"
TEST_FWI_FALLBACK: str = f"{DATA_FALLBACK}/test_data_FWI.csv"
TEST_FWI_TO_PREDICT: str = f"{DATA_FALLBACK}/fwi_test_to_predict.csv"
TEST_ERA_TO_PREDICT: str = f"{DATA_FALLBACK}/era_test_to_predict.csv"
Expand All @@ -37,7 +38,30 @@
CDS_API_KEY = os.getenv('CDS_API_KEY')

RFMODEL_PATH: str = f"{DATA_FALLBACK}/pyrorisk_rfc_111220.pkl"
RFMODEL_ERA5T_PATH: str = f"{DATA_FALLBACK}/pyrorisk_rfc_era5t_151220.pkl"
XGBMODEL_PATH: str = f"{DATA_FALLBACK}/pyrorisk_xgb_091220.pkl"
XGBMODEL_ERA5T_PATH: str = f"{DATA_FALLBACK}/pyrorisk_xgb_era5t_151220.pkl"

FWI_VARS = ['fwi', 'ffmc', 'dmc', 'dc', 'isi', 'bui', 'dsr']
WEATHER_VARS = [
'u10', 'v10', 'd2m', 't2m', 'fal', 'lai_hv', 'lai_lv', 'skt',
'asn', 'snowc', 'rsn', 'sde', 'sd', 'sf', 'smlt', 'stl1', 'stl2',
'stl3', 'stl4', 'slhf', 'ssr', 'str', 'sp', 'sshf', 'ssrd', 'strd', 'tsn', 'tp'
]
WEATHER_ERA5T_VARS = ['asn', 'd2m', 'e', 'es', 'fal', 'lai_hv', 'lai_lv', 'lblt',
'licd', 'lict', 'lmld', 'lmlt', 'lshf', 'ltlt', 'pev', 'ro', 'rsn', 'sd', 'sf', 'skt',
'slhf', 'smlt', 'sp', 'src', 'sro', 'sshf', 'ssr', 'ssrd', 'ssro', 'stl1', 'stl2', 'stl3',
'stl4', 'str', 'strd', 'swvl1', 'swvl2', 'swvl3', 'swvl4', 't2m', 'tp', 'tsn', 'u10', 'v10']

MODEL_ERA5T_VARS = ['str_max', 'str_mean', 'ffmc_min', 'str_min', 'ffmc_mean',
'str_mean_lag1', 'str_max_lag1', 'str_min_lag1', 'isi_min',
'ffmc_min_lag1', 'isi_mean', 'ffmc_mean_lag1', 'ffmc_std', 'ffmc_max',
'isi_min_lag1', 'isi_mean_lag1', 'ffmc_max_lag1', 'asn_std', 'strd_max',
'ssrd_min', 'strd_mean', 'isi_max', 'strd_min', 'd2m_min', 'asn_min',
'ssr_min', 'ffmc_min_lag3', 'ffmc_std_lag1', 'lai_hv_mean_lag7',
'str_max_lag3', 'str_mean_lag3', 'rsn_std_lag1', 'fwi_mean', 'ssr_mean',
'ssrd_mean', 'swvl1_mean', 'rsn_std_lag3', 'isi_max_lag1', 'd2m_mean',
'rsn_std']

MODEL_VARIABLES = ['ffmc_min', 'str_mean', 'str_min', 'str_max', 'ffmc_mean', 'isi_min',
'ffmc_min_lag1', 'strd_mean', 'isi_mean', 'strd_min', 'strd_max',
Expand All @@ -48,6 +72,7 @@
'strd_min_lag1', 'ffmc_min_lag3', 'ffmc_std_lag1', 'strd_mean_lag1',
'rsn_mean_lag1', 'fwi_mean', 'isi_max_lag1', 'sd_max', 'strd_max_lag1',
'rsn_mean', 'snowc_std_lag7', 'stl1_std_lag3']

TRAIN_SELECTED_DEP = ['Aisne', 'Alpes-Maritimes', 'Ardèche', 'Ariège', 'Aude', 'Aveyron',
'Cantal', 'Eure', 'Eure-et-Loir', 'Gironde', 'Haute-Corse', 'Hautes-Pyrénées',
'Hérault', 'Indre', 'Landes', 'Loiret', 'Lozère', 'Marne', 'Oise',
Expand Down
21 changes: 7 additions & 14 deletions pyro_risks/datasets/era_fwi_viirs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import pandas as pd

from pyro_risks.datasets import NASAFIRMS_VIIRS, ERA5Land
from pyro_risks.datasets import NASAFIRMS_VIIRS, ERA5Land, ERA5T
from pyro_risks.datasets.utils import get_intersection_range
from pyro_risks.datasets.fwi import GwisFwi
from pyro_risks import config as cfg

__all__ = ["MergedEraFwiViirs"]

Expand All @@ -28,19 +29,15 @@ def process_dataset_to_predict(fwi, era):

# Group fwi dataframe by day and department and compute min, max, mean, std
agg_fwi_df = fwi_df.groupby(['day', 'nom'])[
'fwi', 'ffmc', 'dmc', 'dc', 'isi', 'bui', 'dsr'
].agg(['min', 'max', 'mean', 'std']).reset_index()
cfg.FWI_VARS].agg(['min', 'max', 'mean', 'std']).reset_index()
agg_fwi_df.columns = ['day', 'nom'] + \
[x[0] + '_' + x[1] for x in agg_fwi_df.columns if x[1] != '']

logger.info("Finished aggregationg of FWI")

# Group weather dataframe by day and department and compute min, max, mean, std
agg_wth_df = weather.groupby(['time', 'nom'])[
'u10', 'v10', 'd2m', 't2m', 'fal', 'lai_hv', 'lai_lv', 'skt',
'asn', 'snowc', 'rsn', 'sde', 'sd', 'sf', 'smlt', 'stl1', 'stl2',
'stl3', 'stl4', 'slhf', 'ssr', 'str', 'sp', 'sshf', 'ssrd', 'strd', 'tsn', 'tp'
].agg(['min', 'max', 'mean', 'std']).reset_index()
cfg.WEATHER_ERA5T_VARS].agg(['min', 'max', 'mean', 'std']).reset_index()
agg_wth_df.columns = ['day', 'nom'] + \
[x[0] + '_' + x[1] for x in agg_wth_df.columns if x[1] != '']

Expand Down Expand Up @@ -75,7 +72,7 @@ def __init__(self, era_source_path=None, viirs_source_path=None, fwi_source_path
viirs_source_path (str, optional): Viirs data source path. Defaults to None.
fwi_source_path (str, optional): Fwi data source path. Defaults to None.
"""
weather = ERA5Land(era_source_path)
weather = ERA5T(era_source_path) # ERA5Land(era_source_path)
nasa_firms = NASAFIRMS_VIIRS(viirs_source_path)

# Time span selection
Expand All @@ -100,17 +97,13 @@ def __init__(self, era_source_path=None, viirs_source_path=None, fwi_source_path

# Group fwi dataframe by day and department and compute min, max, mean, std
agg_fwi_df = fwi_df.groupby(['day', 'departement'])[
'fwi', 'ffmc', 'dmc', 'dc', 'isi', 'bui', 'dsr'
].agg(['min', 'max', 'mean', 'std']).reset_index()
cfg.FWI_VARS].agg(['min', 'max', 'mean', 'std']).reset_index()
agg_fwi_df.columns = ['day', 'departement'] + \
[x[0] + '_' + x[1] for x in agg_fwi_df.columns if x[1] != '']

# Group weather dataframe by day and department and compute min, max, mean, std
agg_wth_df = weather.groupby(['time', 'nom'])[
'u10', 'v10', 'd2m', 't2m', 'fal', 'lai_hv', 'lai_lv', 'skt',
'asn', 'snowc', 'rsn', 'sde', 'sd', 'sf', 'smlt', 'stl1', 'stl2',
'stl3', 'stl4', 'slhf', 'ssr', 'str', 'sp', 'sshf', 'ssrd', 'strd', 'tsn', 'tp'
].agg(['min', 'max', 'mean', 'std']).reset_index()
cfg.WEATHER_ERA5T_VARS].agg(['min', 'max', 'mean', 'std']).reset_index()
agg_wth_df.columns = ['day', 'departement'] + \
[x[0] + '_' + x[1] for x in agg_wth_df.columns if x[1] != '']

Expand Down
31 changes: 22 additions & 9 deletions pyro_risks/models/predict.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import joblib
from urllib.request import urlopen
import xgboost

from pyro_risks import config as cfg
from pyro_risks.datasets.fwi import get_fwi_data_for_predict
from pyro_risks.datasets.ERA5 import get_data_era5land_for_predict
from pyro_risks.datasets.ERA5 import get_data_era5land_for_predict, get_data_era5t_for_predict
from pyro_risks.datasets.era_fwi_viirs import process_dataset_to_predict
from pyro_risks.models.score_v0 import add_lags

Expand All @@ -28,10 +29,13 @@ def __init__(self, which='RF'):
which (str, optional): Can be 'RF' for random forest or 'XGB' for xgboost. Defaults to 'RF'.
"""
if which == 'RF':
self.model_path = cfg.RFMODEL_PATH
self.model_path = cfg.RFMODEL_ERA5T_PATH
elif which == 'XGB':
self.model_path = cfg.XGBMODEL_PATH
self.model_path = cfg.XGBMODEL_ERA5T_PATH
else:
raise ValueError("Model can be only of type RF or XGB")
self.model = joblib.load(urlopen(self.model_path))
self._model_type = which

def get_input(self, day):
"""Returns for a given day data to feed into the model.
Expand All @@ -45,12 +49,15 @@ def get_input(self, day):
Returns:
pd.DataFrame
"""
model_cols = cfg.MODEL_VARIABLES
model_cols = cfg.MODEL_ERA5T_VARS
fwi = get_fwi_data_for_predict(day)
era = get_data_era5land_for_predict(day)
era = get_data_era5t_for_predict(day)
res_test = process_dataset_to_predict(fwi, era)
res_test = res_test.rename({'nom': 'departement'}, axis=1)
res_lags = add_lags(res_test, res_test.drop(['day', 'departement'], axis=1).columns)
# Add lags only for columns on which model was trained on
cols_lags = ['_'.join(x.split('_')[:-1]) for x in cfg.MODEL_ERA5T_VARS if '_lag' in x]
res_lags = add_lags(res_test, cols_lags)
# Select only rows corresponding to day
to_predict = res_lags.loc[res_lags['day'] == day]
to_predict = to_predict.drop('day', axis=1).set_index('departement')
# Some NaN due to the aggregations on departments with only one line (variables with std)
Expand All @@ -68,9 +75,15 @@ def predict(self, day, country='France'):
country (str, optional): Defaults to 'France'.
Returns:
dict: keys are departements and values model probability predictions for label 1 (fire)
dict: keys are departements, values dictionaries whose keys are score and explainability
and values probability predictions for label 1 (fire) and feature contributions to predictions
respectively
"""
sample = self.get_input(day)
predictions = self.model.predict_proba(sample.values)
res = dict(zip(sample.index, predictions[:, 1].round(3)))
if self._model_type == 'RF':
predictions = self.model.predict_proba(sample.values)
res = dict(zip(sample.index, predictions[:, 1].round(3)))
elif self._model_type == 'XGB':
predictions = self.model.predict(xgboost.DMatrix(sample))
res = dict(zip(sample.index, predictions.round(3)))
return {x: {'score': res[x], 'explainability': None} for x in res}
3 changes: 2 additions & 1 deletion test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,12 @@ def test_era5t(self):

def test_MergedEraFwiViirs(self):
ds = era_fwi_viirs.MergedEraFwiViirs(
era_source_path=cfg.TEST_FR_ERA5_2019_FALLBACK,
era_source_path=cfg.TEST_FR_ERA5T_FALLBACK,
viirs_source_path=None,
fwi_source_path=cfg.TEST_FWI_FALLBACK,
)
self.assertIsInstance(ds, pd.DataFrame)
self.assertTrue(len(ds) > 0)

def test_call_era5land(self):
with tempfile.TemporaryDirectory() as tmp:
Expand Down
6 changes: 3 additions & 3 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ def test_xgb_model(self):
def test_pyrorisk(self):
pr = predict.PyroRisk(which='RF')
self.assertEqual(pr.model.n_estimators, 500)
self.assertEqual(pr.model_path, cfg.RFMODEL_PATH)
self.assertEqual(pr.model_path, cfg.RFMODEL_ERA5T_PATH)
res = pr.get_input('2020-05-05')
self.assertIsInstance(res, pd.DataFrame)
self.assertEqual(res.shape, (93, 41))
self.assertEqual(res.shape, (93, 40))
preds = pr.predict('2020-05-05')
self.assertEqual(len(preds), 93)
self.assertEqual(preds['Ardennes'], {'score': 0.11, 'explainability': None})
self.assertEqual(preds['Ardennes'], {'score': 0.246, 'explainability': None})


if __name__ == "__main__":
Expand Down

0 comments on commit 8a66bd1

Please sign in to comment.