diff --git a/pyro_risks/config.py b/pyro_risks/config.py index 10dffb4..c07fe5e 100644 --- a/pyro_risks/config.py +++ b/pyro_risks/config.py @@ -24,6 +24,7 @@ TEST_FR_VIIRS_XLSX_FALLBACK: str = f"{DATA_FALLBACK}/test_data_VIIRS.xlsx" TEST_FR_VIIRS_JSON_FALLBACK: str = f"{DATA_FALLBACK}/test_data_VIIRS.json" TEST_FR_ERA5_2019_FALLBACK: str = f"{DATA_FALLBACK}/test_data_ERA5_2019.nc" +TEST_FR_ERA5T_FALLBACK: str = f"{DATA_FALLBACK}/test_era5t_to_merge.nc" TEST_FWI_FALLBACK: str = f"{DATA_FALLBACK}/test_data_FWI.csv" TEST_FWI_TO_PREDICT: str = f"{DATA_FALLBACK}/fwi_test_to_predict.csv" TEST_ERA_TO_PREDICT: str = f"{DATA_FALLBACK}/era_test_to_predict.csv" @@ -37,7 +38,30 @@ CDS_API_KEY = os.getenv('CDS_API_KEY') RFMODEL_PATH: str = f"{DATA_FALLBACK}/pyrorisk_rfc_111220.pkl" +RFMODEL_ERA5T_PATH: str = f"{DATA_FALLBACK}/pyrorisk_rfc_era5t_151220.pkl" XGBMODEL_PATH: str = f"{DATA_FALLBACK}/pyrorisk_xgb_091220.pkl" +XGBMODEL_ERA5T_PATH: str = f"{DATA_FALLBACK}/pyrorisk_xgb_era5t_151220.pkl" + +FWI_VARS = ['fwi', 'ffmc', 'dmc', 'dc', 'isi', 'bui', 'dsr'] +WEATHER_VARS = [ + 'u10', 'v10', 'd2m', 't2m', 'fal', 'lai_hv', 'lai_lv', 'skt', + 'asn', 'snowc', 'rsn', 'sde', 'sd', 'sf', 'smlt', 'stl1', 'stl2', + 'stl3', 'stl4', 'slhf', 'ssr', 'str', 'sp', 'sshf', 'ssrd', 'strd', 'tsn', 'tp' +] +WEATHER_ERA5T_VARS = ['asn', 'd2m', 'e', 'es', 'fal', 'lai_hv', 'lai_lv', 'lblt', + 'licd', 'lict', 'lmld', 'lmlt', 'lshf', 'ltlt', 'pev', 'ro', 'rsn', 'sd', 'sf', 'skt', + 'slhf', 'smlt', 'sp', 'src', 'sro', 'sshf', 'ssr', 'ssrd', 'ssro', 'stl1', 'stl2', 'stl3', + 'stl4', 'str', 'strd', 'swvl1', 'swvl2', 'swvl3', 'swvl4', 't2m', 'tp', 'tsn', 'u10', 'v10'] + +MODEL_ERA5T_VARS = ['str_max', 'str_mean', 'ffmc_min', 'str_min', 'ffmc_mean', + 'str_mean_lag1', 'str_max_lag1', 'str_min_lag1', 'isi_min', + 'ffmc_min_lag1', 'isi_mean', 'ffmc_mean_lag1', 'ffmc_std', 'ffmc_max', + 'isi_min_lag1', 'isi_mean_lag1', 'ffmc_max_lag1', 'asn_std', 'strd_max', + 'ssrd_min', 'strd_mean', 'isi_max', 'strd_min', 'd2m_min', 'asn_min', + 'ssr_min', 'ffmc_min_lag3', 'ffmc_std_lag1', 'lai_hv_mean_lag7', + 'str_max_lag3', 'str_mean_lag3', 'rsn_std_lag1', 'fwi_mean', 'ssr_mean', + 'ssrd_mean', 'swvl1_mean', 'rsn_std_lag3', 'isi_max_lag1', 'd2m_mean', + 'rsn_std'] MODEL_VARIABLES = ['ffmc_min', 'str_mean', 'str_min', 'str_max', 'ffmc_mean', 'isi_min', 'ffmc_min_lag1', 'strd_mean', 'isi_mean', 'strd_min', 'strd_max', @@ -48,6 +72,7 @@ 'strd_min_lag1', 'ffmc_min_lag3', 'ffmc_std_lag1', 'strd_mean_lag1', 'rsn_mean_lag1', 'fwi_mean', 'isi_max_lag1', 'sd_max', 'strd_max_lag1', 'rsn_mean', 'snowc_std_lag7', 'stl1_std_lag3'] + TRAIN_SELECTED_DEP = ['Aisne', 'Alpes-Maritimes', 'Ardèche', 'Ariège', 'Aude', 'Aveyron', 'Cantal', 'Eure', 'Eure-et-Loir', 'Gironde', 'Haute-Corse', 'Hautes-Pyrénées', 'Hérault', 'Indre', 'Landes', 'Loiret', 'Lozère', 'Marne', 'Oise', diff --git a/pyro_risks/datasets/era_fwi_viirs.py b/pyro_risks/datasets/era_fwi_viirs.py index 0748534..fe89d08 100644 --- a/pyro_risks/datasets/era_fwi_viirs.py +++ b/pyro_risks/datasets/era_fwi_viirs.py @@ -1,9 +1,10 @@ import logging import pandas as pd -from pyro_risks.datasets import NASAFIRMS_VIIRS, ERA5Land +from pyro_risks.datasets import NASAFIRMS_VIIRS, ERA5Land, ERA5T from pyro_risks.datasets.utils import get_intersection_range from pyro_risks.datasets.fwi import GwisFwi +from pyro_risks import config as cfg __all__ = ["MergedEraFwiViirs"] @@ -28,8 +29,7 @@ def process_dataset_to_predict(fwi, era): # Group fwi dataframe by day and department and compute min, max, mean, std agg_fwi_df = fwi_df.groupby(['day', 'nom'])[ - 'fwi', 'ffmc', 'dmc', 'dc', 'isi', 'bui', 'dsr' - ].agg(['min', 'max', 'mean', 'std']).reset_index() + cfg.FWI_VARS].agg(['min', 'max', 'mean', 'std']).reset_index() agg_fwi_df.columns = ['day', 'nom'] + \ [x[0] + '_' + x[1] for x in agg_fwi_df.columns if x[1] != ''] @@ -37,10 +37,7 @@ def process_dataset_to_predict(fwi, era): # Group weather dataframe by day and department and compute min, max, mean, std agg_wth_df = weather.groupby(['time', 'nom'])[ - 'u10', 'v10', 'd2m', 't2m', 'fal', 'lai_hv', 'lai_lv', 'skt', - 'asn', 'snowc', 'rsn', 'sde', 'sd', 'sf', 'smlt', 'stl1', 'stl2', - 'stl3', 'stl4', 'slhf', 'ssr', 'str', 'sp', 'sshf', 'ssrd', 'strd', 'tsn', 'tp' - ].agg(['min', 'max', 'mean', 'std']).reset_index() + cfg.WEATHER_ERA5T_VARS].agg(['min', 'max', 'mean', 'std']).reset_index() agg_wth_df.columns = ['day', 'nom'] + \ [x[0] + '_' + x[1] for x in agg_wth_df.columns if x[1] != ''] @@ -75,7 +72,7 @@ def __init__(self, era_source_path=None, viirs_source_path=None, fwi_source_path viirs_source_path (str, optional): Viirs data source path. Defaults to None. fwi_source_path (str, optional): Fwi data source path. Defaults to None. """ - weather = ERA5Land(era_source_path) + weather = ERA5T(era_source_path) # ERA5Land(era_source_path) nasa_firms = NASAFIRMS_VIIRS(viirs_source_path) # Time span selection @@ -100,17 +97,13 @@ def __init__(self, era_source_path=None, viirs_source_path=None, fwi_source_path # Group fwi dataframe by day and department and compute min, max, mean, std agg_fwi_df = fwi_df.groupby(['day', 'departement'])[ - 'fwi', 'ffmc', 'dmc', 'dc', 'isi', 'bui', 'dsr' - ].agg(['min', 'max', 'mean', 'std']).reset_index() + cfg.FWI_VARS].agg(['min', 'max', 'mean', 'std']).reset_index() agg_fwi_df.columns = ['day', 'departement'] + \ [x[0] + '_' + x[1] for x in agg_fwi_df.columns if x[1] != ''] # Group weather dataframe by day and department and compute min, max, mean, std agg_wth_df = weather.groupby(['time', 'nom'])[ - 'u10', 'v10', 'd2m', 't2m', 'fal', 'lai_hv', 'lai_lv', 'skt', - 'asn', 'snowc', 'rsn', 'sde', 'sd', 'sf', 'smlt', 'stl1', 'stl2', - 'stl3', 'stl4', 'slhf', 'ssr', 'str', 'sp', 'sshf', 'ssrd', 'strd', 'tsn', 'tp' - ].agg(['min', 'max', 'mean', 'std']).reset_index() + cfg.WEATHER_ERA5T_VARS].agg(['min', 'max', 'mean', 'std']).reset_index() agg_wth_df.columns = ['day', 'departement'] + \ [x[0] + '_' + x[1] for x in agg_wth_df.columns if x[1] != ''] diff --git a/pyro_risks/models/predict.py b/pyro_risks/models/predict.py index dd62cb8..1aff81d 100644 --- a/pyro_risks/models/predict.py +++ b/pyro_risks/models/predict.py @@ -1,9 +1,10 @@ import joblib from urllib.request import urlopen +import xgboost from pyro_risks import config as cfg from pyro_risks.datasets.fwi import get_fwi_data_for_predict -from pyro_risks.datasets.ERA5 import get_data_era5land_for_predict +from pyro_risks.datasets.ERA5 import get_data_era5land_for_predict, get_data_era5t_for_predict from pyro_risks.datasets.era_fwi_viirs import process_dataset_to_predict from pyro_risks.models.score_v0 import add_lags @@ -28,10 +29,13 @@ def __init__(self, which='RF'): which (str, optional): Can be 'RF' for random forest or 'XGB' for xgboost. Defaults to 'RF'. """ if which == 'RF': - self.model_path = cfg.RFMODEL_PATH + self.model_path = cfg.RFMODEL_ERA5T_PATH elif which == 'XGB': - self.model_path = cfg.XGBMODEL_PATH + self.model_path = cfg.XGBMODEL_ERA5T_PATH + else: + raise ValueError("Model can be only of type RF or XGB") self.model = joblib.load(urlopen(self.model_path)) + self._model_type = which def get_input(self, day): """Returns for a given day data to feed into the model. @@ -45,12 +49,15 @@ def get_input(self, day): Returns: pd.DataFrame """ - model_cols = cfg.MODEL_VARIABLES + model_cols = cfg.MODEL_ERA5T_VARS fwi = get_fwi_data_for_predict(day) - era = get_data_era5land_for_predict(day) + era = get_data_era5t_for_predict(day) res_test = process_dataset_to_predict(fwi, era) res_test = res_test.rename({'nom': 'departement'}, axis=1) - res_lags = add_lags(res_test, res_test.drop(['day', 'departement'], axis=1).columns) + # Add lags only for columns on which model was trained on + cols_lags = ['_'.join(x.split('_')[:-1]) for x in cfg.MODEL_ERA5T_VARS if '_lag' in x] + res_lags = add_lags(res_test, cols_lags) + # Select only rows corresponding to day to_predict = res_lags.loc[res_lags['day'] == day] to_predict = to_predict.drop('day', axis=1).set_index('departement') # Some NaN due to the aggregations on departments with only one line (variables with std) @@ -68,9 +75,15 @@ def predict(self, day, country='France'): country (str, optional): Defaults to 'France'. Returns: - dict: keys are departements and values model probability predictions for label 1 (fire) + dict: keys are departements, values dictionaries whose keys are score and explainability + and values probability predictions for label 1 (fire) and feature contributions to predictions + respectively """ sample = self.get_input(day) - predictions = self.model.predict_proba(sample.values) - res = dict(zip(sample.index, predictions[:, 1].round(3))) + if self._model_type == 'RF': + predictions = self.model.predict_proba(sample.values) + res = dict(zip(sample.index, predictions[:, 1].round(3))) + elif self._model_type == 'XGB': + predictions = self.model.predict(xgboost.DMatrix(sample)) + res = dict(zip(sample.index, predictions.round(3))) return {x: {'score': res[x], 'explainability': None} for x in res} diff --git a/test/test_datasets.py b/test/test_datasets.py index 1e5fd1f..fdcb0e3 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -428,11 +428,12 @@ def test_era5t(self): def test_MergedEraFwiViirs(self): ds = era_fwi_viirs.MergedEraFwiViirs( - era_source_path=cfg.TEST_FR_ERA5_2019_FALLBACK, + era_source_path=cfg.TEST_FR_ERA5T_FALLBACK, viirs_source_path=None, fwi_source_path=cfg.TEST_FWI_FALLBACK, ) self.assertIsInstance(ds, pd.DataFrame) + self.assertTrue(len(ds) > 0) def test_call_era5land(self): with tempfile.TemporaryDirectory() as tmp: diff --git a/test/test_models.py b/test/test_models.py index 11ec338..536ee80 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -79,13 +79,13 @@ def test_xgb_model(self): def test_pyrorisk(self): pr = predict.PyroRisk(which='RF') self.assertEqual(pr.model.n_estimators, 500) - self.assertEqual(pr.model_path, cfg.RFMODEL_PATH) + self.assertEqual(pr.model_path, cfg.RFMODEL_ERA5T_PATH) res = pr.get_input('2020-05-05') self.assertIsInstance(res, pd.DataFrame) - self.assertEqual(res.shape, (93, 41)) + self.assertEqual(res.shape, (93, 40)) preds = pr.predict('2020-05-05') self.assertEqual(len(preds), 93) - self.assertEqual(preds['Ardennes'], {'score': 0.11, 'explainability': None}) + self.assertEqual(preds['Ardennes'], {'score': 0.246, 'explainability': None}) if __name__ == "__main__":