design_info is needed to make predictions on new data #121

Open

samuelefiorini opened this issue Jun 22, 2023 · 4 comments

Comments

@samuelefiorini

According to the documentation

The design info is not useful in general (to reproduce results, make predictions), and not dumping it will save a lot of time.

However, trying to make predictions on new data with a model restored via forecaster.load_forecast_result(path, load_design_info=False) leads to the following error.

File /var/lang/lib/python3.10/site-packages/greykite/algo/common/ml_models.py:715, in predict_ml(fut_df, trained_model)
    713 y_col = trained_model["y_col"]
    714 ml_model = trained_model["ml_model"]
--> 715 x_design_info = trained_model["x_design_info"]
    716 drop_intercept_col = trained_model["drop_intercept_col"]
    717 min_admissible_value = trained_model["min_admissible_value"]

KeyError: 'x_design_info'

This makes sense when looking at greykite.algo.common.ml_models.predict_ml, since the variable x_design_info is used by patsy to build the design matrix (see here).
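
For context, here is a minimal sketch of that mechanism using patsy directly (illustrative only, not Greykite code, and the column names are made up): patsy rebuilds the design matrix for new rows from the DesignInfo captured at fit time, so without it the prediction-time matrix cannot be constructed.

import pandas as pd
from patsy import dmatrix, build_design_matrices

train = pd.DataFrame({"dow": ["Mon", "Tue", "Wed"], "x": [1.0, 2.0, 3.0]})
new = pd.DataFrame({"dow": ["Tue"], "x": [4.0]})

# What Greykite stores as x_design_info: the recipe used to build the training design matrix.
design_info = dmatrix("C(dow) + x", train).design_info
# Rebuilding the matrix for the new rows requires that stored design_info.
x_new = build_design_matrices([design_info], new)[0]
print(x_new)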

On the other hand, dumping design_info not only implies a larger artifact, but may even be impossible due to filesystem limits on the length of the generated filenames.

As an example, this is what happens in my case.

OSError: [Errno 36] File name too long: '/opt/ml/model/5f4cafc99b894af398c02013e13348e2/artifacts/forecast_result/grid_search/best_estimator_/steps/2_key/1_key/model_dict/x_design_info__value__/factor_infos/EvalFactor("C(Q(\'dow_hr\'), levels=[\'1_00\', \'1_01\', \'1_02\', \'1_03\', \'1_04\', \'1_05\', \'1_06\', \'1_07\', \'1_08\', \'1_09\', \'1_10\', \'1_11\', \'1_12\', \'1_13\', \'1_14\', \'1_15\', \'1_16\', \'1_17\', \'1_18\', \'1_19\', \'1_20\', \'1_21\', \'1_22\', \'1_23\', \'2_00\', \'2_01\', \'2_02\', \'2_03\', \'2_04\', \'2_05\', \'2_06\', \'2_07\', \'2_08\', \'2_09\', \'2_10\', \'2_11\', \'2_12\', \'2_13\', \'2_14\', \'2_15\', \'2_16\', \'2_17\', \'2_18\', \'2_19\', \'2_20\', \'2_21\', \'2_22\', \'2_23\', \'3_00\', \'3_01\', \'3_02\', \'3_03\', \'3_04\', \'3_05\', \'3_06\', \'3_07\', \'3_08\', \'3_09\', \'3_10\', \'3_11\', \'3_12\', \'3_13\', \'3_14\', \'3_15\', \'3_16\', \'3_17\', \'3_18\', \'3_19\', \'3_20\', \'3_21\', \'3_22\', \'3_23\', \'4_00\', \'4_01\', \'4_02\', \'4_03\', \'4_04\', \'4_05\', \'4_06\', \'4_07\', \'4_08\', \'4_09\', \'4_10\', \'4_11\', \'4_12\', \'4_13\', \'4_14\', \'4_15\', \'4_16\', \'4_17\', \'4_18\', \'4_19\', \'4_20\', \'4_21\', \'4_22\', \'4_23\', \'5_00\', \'5_01\', \'5_02\', \'5_03\', \'5_04\', \'5_05\', \'5_06\', \'5_07\', \'5_08\', \'5_09\', \'5_10\', \'5_11\', \'5_12\', \'5_13\', \'5_14\', \'5_15\', \'5_16\', \'5_17\', \'5_18\', \'5_19\', \'5_20\', \'5_21\', \'5_22\', \'5_23\', \'6_00\', \'6_01\', \'6_02\', \'6_03\', \'6_04\', \'6_05\', \'6_06\', \'6_07\', \'6_08\', \'6_09\', \'6_10\', \'6_11\', \'6_12\', \'6_13\', \'6_14\', \'6_15\', \'6_16\', \'6_17\', \'6_18\', \'6_19\', \'6_20\', \'6_21\', \'6_22\', \'6_23\', \'7_00\', \'7_01\', \'7_02\', \'7_03\', \'7_04\', \'7_05\', \'7_06\', \'7_07\', \'7_08\', \'7_09\', \'7_10\', \'7_11\', \'7_12\', \'7_13\', \'7_14\', \'7_15\', \'7_16\', \'7_17\', \'7_18\', \'7_19\', \'7_20\', \'7_21\', \'7_22\', \'7_23\'])")__key__.pkl'
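
For reference, the limit being hit is the per-component file name limit of the filesystem (a quick check, assuming a typical Linux setup; the path is the one from the error above):

import os
print(os.pathconf("/opt/ml/model", "PC_NAME_MAX"))  # typically 255 bytes per path component
# The EvalFactor repr used as a file name above is far longer than that, hence Errno 36.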

Any ideas on how to work around this issue?


@samuelefiorini

In order to reproduce the error above, see the following minimal script.

import warnings

warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

from greykite.common.data_loader import DataLoader
from greykite.framework.templates.autogen.forecast_config import ForecastConfig
from greykite.framework.templates.autogen.forecast_config import MetadataParam
from greykite.framework.templates.autogen.forecast_config import ModelComponentsParam
from greykite.framework.templates.forecaster import Forecaster
from greykite.framework.templates.model_templates import ModelTemplateEnum

pd.options.plotting.backend = 'plotly'

# Defines inputs
df = DataLoader().load_bikesharing().tail(24*90)  # Input time series (pandas.DataFrame)
df['ts'] = pd.to_datetime(df['ts'])

forecast_horizon = 24*2

df.loc[df.index[-forecast_horizon:], 'count'] = np.nan

print(df.head())
print(df.tail())

config = ForecastConfig(
     metadata_param=MetadataParam(time_col="ts", value_col="count"),  # Column names in `df`
     model_template=ModelTemplateEnum.AUTO.name,  # AUTO model configuration
     forecast_horizon=forecast_horizon,   # Forecasts all the missing steps
     model_components_param=ModelComponentsParam(regressors={"regressor_cols": ['tmin', 'tmax', 'pn']}),
     coverage=0.95,         # 95% prediction intervals
)

# Creates forecasts
forecaster = Forecaster()
result = forecaster.run_forecast_config(df=df, config=config)

forecaster.dump_forecast_result(
    '/tmp/forecaster',
    object_name="object",
    dump_design_info=False,
    overwrite_exist_dir=True
)

# Recreate Forecaster
new_forecaster = Forecaster()
new_forecaster.load_forecast_result(
    '/tmp/forecaster',
    load_design_info=False
)
new_result = new_forecaster.forecast_result

new_pred = new_result.model.predict(df.rename(columns={'count': 'y'}))

print(new_pred)

which gives the following output

 date                  ts  count  tmin  tmax   pn
76261  2019-06-03 2019-06-03 01:00:00   35.0  11.7  23.3  0.0
76262  2019-06-03 2019-06-03 02:00:00   20.0  11.7  23.3  0.0
76263  2019-06-03 2019-06-03 03:00:00    9.0  11.7  23.3  0.0
76264  2019-06-03 2019-06-03 04:00:00   14.0  11.7  23.3  0.0
76265  2019-06-03 2019-06-03 05:00:00   37.0  11.7  23.3  0.0
             date                  ts  count  tmin  tmax   pn
78416  2019-08-31 2019-08-31 20:00:00    NaN  17.8  31.1  0.0
78417  2019-08-31 2019-08-31 21:00:00    NaN  17.8  31.1  0.0
78418  2019-08-31 2019-08-31 22:00:00    NaN  17.8  31.1  0.0
78419  2019-08-31 2019-08-31 23:00:00    NaN  17.8  31.1  0.0
78420  2019-09-01 2019-09-01 00:00:00    NaN  21.1  28.3  0.0
Fitting 3 folds for each of 1 candidates, totalling 3 fits
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-15-06ddd6580090> in <module>
----> 1 new_pred = new_result.model.predict(df.rename(columns={'count': 'y'}))

/opt/conda/lib/python3.7/site-packages/sklearn/utils/metaestimators.py in <lambda>(*args, **kwargs)
    111 
    112             # lambda, but not partial, allows help() to work with update_wrapper
--> 113             out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)  # noqa
    114         else:
    115 

/opt/conda/lib/python3.7/site-packages/sklearn/pipeline.py in predict(self, X, **predict_params)
    468         for _, name, transform in self._iter(with_final=False):
    469             Xt = transform.transform(Xt)
--> 470         return self.steps[-1][1].predict(Xt, **predict_params)
    471 
    472     @available_if(_final_estimator_has("fit_predict"))

/opt/conda/lib/python3.7/site-packages/greykite/sklearn/estimator/base_silverkite_estimator.py in predict(self, X, y)
    365             trained_model=self.model_dict,
    366             past_df=self.past_df,
--> 367             new_external_regressor_df=None)  # regressors are included in X
    368         pred_df = pred_res["fut_df"]
    369         x_mat = pred_res["x_mat"]

/opt/conda/lib/python3.7/site-packages/greykite/algo/forecast/silverkite/forecast_silverkite.py in predict(self, fut_df, trained_model, freq, past_df, new_external_regressor_df, include_err, force_no_sim, simulation_num, fast_simulation, na_fill_func)
   2266                     new_external_regressor_df=None,
   2267                     time_features_ready=False,
-> 2268                     regressors_ready=True)
   2269                 fut_df0 = pred_res["fut_df"]
   2270                 x_mat0 = pred_res["x_mat"]

/opt/conda/lib/python3.7/site-packages/greykite/algo/forecast/silverkite/forecast_silverkite.py in predict_no_sim(self, fut_df, trained_model, past_df, new_external_regressor_df, time_features_ready, regressors_ready)
   1226             pred_res = predict_ml_with_uncertainty(
   1227                 fut_df=features_df_fut,
-> 1228                 trained_model=trained_model)
   1229             fut_df = pred_res["fut_df"]
   1230             x_mat = pred_res["x_mat"]

/opt/conda/lib/python3.7/site-packages/greykite/algo/common/ml_models.py in predict_ml_with_uncertainty(fut_df, trained_model)
    558     pred_res = predict_ml(
    559         fut_df=fut_df,
--> 560         trained_model=trained_model)
    561 
    562     y_pred = pred_res["fut_df"][y_col]

/opt/conda/lib/python3.7/site-packages/greykite/algo/common/ml_models.py in predict_ml(fut_df, trained_model)
    501     y_col = trained_model["y_col"]
    502     ml_model = trained_model["ml_model"]
--> 503     x_design_info = trained_model["x_design_info"]
    504     min_admissible_value = trained_model["min_admissible_value"]
    505     max_admissible_value = trained_model["max_admissible_value"]

KeyError: 'x_design_info'

This is a serious issue, as it makes Greykite unusable when predictions need to be made with a model restored from disk.

@samuelefiorini

We have found a workaround for this issue. With the help of @andreaschiappacasse and his insightful intuition, we implemented a slightly modified version of the model dump & load process.

This modification prevents the system from generating pickle files with excessively long names. As a result, we can now dump models with the design info matrix and subsequently load them to perform predictions.
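
The core of the change is the naming scheme for the dumped dictionary keys: instead of deriving file names from str(key) (which for patsy EvalFactor keys reproduces the entire formula, as in the path above), the names are derived from hash(key), whose string form has a short, bounded length. A small illustration (the key below is a made-up string, not an actual EvalFactor):

long_key = "C(Q('dow_hr'), levels=['1_00', '1_01', '1_02', ...])"  # in practice far longer than 255 characters
old_name = str(long_key)        # old scheme: file name grows with the key itself
new_name = str(hash(long_key))  # new scheme: short decimal string, e.g. '-4931043125784157324'
print(len(old_name), len(new_name))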

Please refer to the implementation below for more details.

import inspect
import os
import shutil
from collections import OrderedDict

import dill
from patsy.design_info import DesignInfo


def dump_obj(obj, dir_name, obj_name="obj", dump_design_info=True, overwrite_exist_dir=False, top_level=True):
    """See `greykite.framework.templates.pickle_utils.dump_obj`."""
    # Checks whether to dump design info.
    if (not dump_design_info) and (isinstance(obj, DesignInfo) or (isinstance(obj, str) and obj == "x_design_info")):
        return

    # Checks if directory already exists.
    if top_level:
        dir_already_exist = os.path.exists(dir_name)
        if dir_already_exist:
            if not overwrite_exist_dir:
                raise FileExistsError(
                    "The directory already exists. "
                    "Please either specify a new directory or "
                    "set overwrite_exist_dir to True to overwrite it."
                )
            else:
                if os.path.isdir(dir_name):
                    # dir exists as a directory.
                    shutil.rmtree(dir_name)
                else:
                    # dir exists as a file.
                    os.remove(dir_name)

    # Creates the directory.
    # Non-top-level calls may write to the same directory,
    # so we allow existing directory in this case.
    try:
        os.mkdir(dir_name)
    except FileExistsError:
        pass

    # Start dumping recursively.
    try:
        # Attempts to directly dump the object.
        dill.dump(obj, open(os.path.join(dir_name, f"{obj_name}.pkl"), "wb"))
    except NotImplementedError:
        # Direct dumping fails.
        # Removes the failed file.
        try:
            os.remove(os.path.join(dir_name, f"{obj_name}.pkl"))
        except FileNotFoundError:
            pass
        # Attempts to do recursive dumping depending on the object type.
        if isinstance(obj, OrderedDict):
            # For OrderedDict (there are a lot in `patsy.design_info.DesignInfo`),
            # recursively dumps the keys and values, because keys can be class instances
            # and unpicklable, too.
            # The keys and values have index number appended to the front,
            # so the order is kept.
            dill.dump("ordered_dict", open(os.path.join(dir_name, f"{obj_name}.type"), "wb"))  # type "ordered_dict"
            for i, (key, value) in enumerate(obj.items()):
                # name = str(key) # this is how it used to be in the "name too-long" version
                name = f"{i}_{str(hash(key))}"  # but we actually don't need to keep the key in name
                dump_obj(
                    key,
                    os.path.join(dir_name, obj_name),
                    f"{name}__key__",
                    dump_design_info=dump_design_info,
                    top_level=False,
                )
                dump_obj(
                    value,
                    os.path.join(dir_name, obj_name),
                    f"{name}__value__",
                    dump_design_info=dump_design_info,
                    top_level=False,
                )
        elif isinstance(obj, dict):
            # For regular dictionary,
            # recursively dumps the keys and values, because keys can be class instances
            # and unpicklable, too.
            # The order is not important.
            dill.dump("dict", open(os.path.join(dir_name, f"{obj_name}.type"), "wb"))  # type "dict"
            for key, value in obj.items():
                # name = str(key) # this is how it used to be in the "name too-long" version
                name = str(hash(key))  # but we actually don't need to keep the key in name
                dump_obj(
                    key,
                    os.path.join(dir_name, obj_name),
                    f"{name}__key__",
                    dump_design_info=dump_design_info,
                    top_level=False,
                )
                dump_obj(
                    value,
                    os.path.join(dir_name, obj_name),
                    f"{name}__value__",
                    dump_design_info=dump_design_info,
                    top_level=False,
                )
        elif isinstance(obj, (list, tuple)):
            # For list and tuples,
            # recursively dumps the elements.
            # The names have index number appended to the front,
            # so the order is kept.
            dill.dump(
                type(obj).__name__, open(os.path.join(dir_name, f"{obj_name}.type"), "wb")
            )  # type "list"/"tuple"
            for i, value in enumerate(obj):
                dump_obj(
                    value,
                    os.path.join(dir_name, obj_name),
                    f"{i}_key",
                    dump_design_info=dump_design_info,
                    top_level=False,
                )
        elif hasattr(obj, "__class__") and not isinstance(obj, type):
            # For class instance,
            # recursively dumps the attributes.
            dill.dump(obj.__class__, open(os.path.join(dir_name, f"{obj_name}.type"), "wb"))  # type is class itself
            for key, value in obj.__dict__.items():
                dump_obj(
                    value, os.path.join(dir_name, obj_name), key, dump_design_info=dump_design_info, top_level=False
                )
        else:
            # Other unrecognized unpicklable types, not common.
            print(f"Unrecognized type {type(obj)}")


def load_obj(dir_name, obj=None, load_design_info=True):
    """See `greykite.framework.templates.pickle_utils.load_obj`."""
    # Checks whether to load design info.
    if (not load_design_info) and (isinstance(obj, type) and obj == DesignInfo):
        return None

    # Gets file names in the level.
    files = os.listdir(dir_name)
    if not files:
        raise ValueError("dir is empty!")

    # Gets the type files if any.
    # Stores in a dictionary with key being the name and value being the loaded value.
    obj_types = {
        file.split(".")[0]: dill.load(open(os.path.join(dir_name, file), "rb")) for file in files if ".type" in file
    }

    # Gets directories and pickled files.
    # Every type must have a directory with the same name.
    directories = [file for file in files if os.path.isdir(os.path.join(dir_name, file))]
    if not all([directory in obj_types for directory in directories]):
        raise ValueError("type and directories do not match.")
    pickles = [file for file in files if ".pkl" in file]

    # Starts loading objects
    if obj is None:
        # obj is None indicates this is the top level directory.
        # This directory can either have 1 .pkl file, or 1 .type file associated with the directory of same name.
        if not obj_types:
            # The only 1 .pkl file case.
            if len(files) > 1:
                raise ValueError("Multiple elements found in top level.")
            return dill.load(open(os.path.join(dir_name, files[0]), "rb"))
        else:
            # The .type + dir case.
            if len(obj_types) > 1:
                raise ValueError("Multiple elements found in top level")
            obj_name = list(obj_types.keys())[0]
            obj_type = obj_types[obj_name]
            return load_obj(os.path.join(dir_name, obj_name), obj_type, load_design_info=load_design_info)
    else:
        # If obj is not None, does recursive loading depending on the obj type.
        if obj in ("list", "tuple"):
            # Object is list or tuple.
            # Fetches each element according to the number index to preserve orders.
            result = []
            # Order index is a number appended to the front.
            elements = sorted(pickles + directories, key=lambda x: int(x.split("_")[0]))
            # Recursively loads elements.
            for element in elements:
                if ".pkl" in element:
                    result.append(dill.load(open(os.path.join(dir_name, element), "rb")))
                else:
                    result.append(
                        load_obj(
                            os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
                        )
                    )
            if obj == "tuple":
                result = tuple(result)
            return result
        elif obj == "dict":
            # Object is a dictionary.
            # Fetches keys and values recursively.
            result = {}
            elements = pickles + directories
            keys = [element for element in elements if "__key__" in element]
            values = [element for element in elements if "__value__" in element]
            # Iterates through keys and finds the corresponding values.
            for element in keys:
                if ".pkl" in element:
                    key = dill.load(open(os.path.join(dir_name, element), "rb"))
                else:
                    key = load_obj(
                        os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
                    )
                # Value name could be either with .pkl or a directory.
                value_name = element.replace("__key__", "__value__")
                if ".pkl" in value_name:
                    value_name_alt = value_name.replace(".pkl", "")
                else:
                    value_name_alt = value_name + ".pkl"
                # Checks if value name is in the dir.
                if (value_name not in values) and (value_name_alt not in values):
                    raise FileNotFoundError(f"Value not found for key {key}.")
                value_name = value_name if value_name in values else value_name_alt
                # Gets the value.
                if ".pkl" in value_name:
                    value = dill.load(open(os.path.join(dir_name, value_name), "rb"))
                else:
                    value = load_obj(
                        os.path.join(dir_name, value_name), obj_types[value_name], load_design_info=load_design_info
                    )
                # Sets the key, value pair.
                result[key] = value
            return result
        elif obj == "ordered_dict":
            # Object is OrderedDict.
            # Fetches keys and values according to the number index to preserve orders.
            result = OrderedDict()
            # Order index is a number appended to the front.
            elements = sorted(pickles + directories, key=lambda x: int(x.split("_")[0]))
            # elements = pickles + directories
            keys = [element for element in elements if "__key__" in element]
            values = [element for element in elements if "__value__" in element]
            # Iterates through keys and finds the corresponding values.
            for element in keys:
                if ".pkl" in element:
                    key = dill.load(open(os.path.join(dir_name, element), "rb"))
                else:
                    key = load_obj(
                        os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
                    )
                value_name = element.replace("__key__", "__value__")
                # Value name could be either with .pkl or a directory.
                if ".pkl" in value_name:
                    value_name_alt = value_name.replace(".pkl", "")
                else:
                    value_name_alt = value_name + ".pkl"
                # Checks if value name is in the dir.
                if (value_name not in values) and (value_name_alt not in values):
                    raise FileNotFoundError(f"Value not found for key {key}.")
                value_name = value_name if value_name in values else value_name_alt
                # Gets the value.
                if ".pkl" in value_name:
                    value = dill.load(open(os.path.join(dir_name, value_name), "rb"))
                else:
                    value = load_obj(
                        os.path.join(dir_name, value_name), obj_types[value_name], load_design_info=load_design_info
                    )
                # Sets the key, value pair.
                result[key] = value
            return result
        elif inspect.isclass(obj):
            # Object is a class instance.
            # Creates the class instance and sets the attributes.
            # Some class has required args during initialization,
            # these args are pulled from attributes.
            init_params = list(inspect.signature(obj.__init__).parameters)  # init args
            elements = pickles + directories
            # Gets the attribute names and their values in a dictionary.
            values = {}
            for element in elements:
                if ".pkl" in element:
                    values[element.split(".")[0]] = dill.load(open(os.path.join(dir_name, element), "rb"))
                else:
                    values[element] = load_obj(
                        os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
                    )
            # Gets the init args from values.
            init_dict = {key: value for key, value in values.items() if key in init_params}
            # Some attributes have a "_" at the beginning.
            init_dict.update(
                {key[1:]: value for key, value in values.items() if (key[1:] in init_params and key[0] == "_")}
            )
            # ``design_info`` does not have column_names attribute,
            # which is required during init.
            # The column_names param is pulled from the column_name_indexes attribute.
            # This can be omitted once we allow dumping @property attributes.
            if "column_names" in init_params:
                init_dict["column_names"] = values["column_name_indexes"].keys()
            # Creates the instance.
            result = obj(**init_dict)
            # Sets the attributes.
            for key, value in values.items():
                setattr(result, key, value)
            return result
        else:
            # Raises an error if the object is not recognized.
            # This typically does not happen when the source file is dumped
            # with the `dump_obj` function.
            raise ValueError(f"Object {obj} is not recognized.")
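
One way to use these replacements end to end is the following untested sketch: it bypasses Forecaster.dump_forecast_result and dumps/loads the ForecastResult directly with the functions above (the path and object name are illustrative, and result/df are the objects from the reproduction script earlier in this thread):

# Dumps the forecast result with the long-filename-safe function above.
dump_obj(
    result,                    # ForecastResult returned by forecaster.run_forecast_config
    dir_name="/tmp/forecaster",
    obj_name="object",
    dump_design_info=True,     # design info can now be kept without hitting Errno 36
    overwrite_exist_dir=True,
)

# Reloads it and predicts on new data, as in the script above.
restored = load_obj("/tmp/forecaster", load_design_info=True)
new_pred = restored.model.predict(df.rename(columns={'count': 'y'}))
print(new_pred)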

@pjgaudre

Thank you for raising this issue and providing an alternative. My colleague @sayanpatra was working on model dumps this quarter. I'll ask him to take a look at it.
