design_info is needed to make predictions on new data #121
Comments
In order to reproduce the error described in this issue, see the following minimal script.

import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from greykite.common.data_loader import DataLoader
from greykite.framework.templates.autogen.forecast_config import ForecastConfig
from greykite.framework.templates.autogen.forecast_config import MetadataParam
from greykite.framework.templates.autogen.forecast_config import ModelComponentsParam
from greykite.framework.templates.forecaster import Forecaster
from greykite.framework.templates.model_templates import ModelTemplateEnum
pd.options.plotting.backend = 'plotly'
# Defines inputs
df = DataLoader().load_bikesharing().tail(24*90) # Input time series (pandas.DataFrame)
df['ts'] = pd.to_datetime(df['ts'])
forecast_horizon = 24*2
df.loc[df.index[-forecast_horizon:], 'count'] = np.nan
print(df.head())
print(df.tail())
config = ForecastConfig(
metadata_param=MetadataParam(time_col="ts", value_col="count"), # Column names in `df`
model_template=ModelTemplateEnum.AUTO.name, # AUTO model configuration
forecast_horizon=forecast_horizon, # Forecasts all the missing steps
model_components_param=ModelComponentsParam(regressors={"regressor_cols": ['tmin', 'tmax', 'pn']}),
coverage=0.95, # 95% prediction intervals
)
# Creates forecasts
forecaster = Forecaster()
result = forecaster.run_forecast_config(df=df, config=config)
forecaster.dump_forecast_result(
'/tmp/forecaster',
object_name="object",
dump_design_info=False,
overwrite_exist_dir=True
)
# Recreate Forecaster
new_forecaster = Forecaster()
new_forecaster.load_forecast_result(
'/tmp/forecaster',
load_design_info=False
)
new_result = new_forecaster.forecast_result
new_pred = new_result.model.predict(df.rename(columns={'count': 'y'}))
print(new_pred)

which gives the following output:

date ts count tmin tmax pn
76261 2019-06-03 2019-06-03 01:00:00 35.0 11.7 23.3 0.0
76262 2019-06-03 2019-06-03 02:00:00 20.0 11.7 23.3 0.0
76263 2019-06-03 2019-06-03 03:00:00 9.0 11.7 23.3 0.0
76264 2019-06-03 2019-06-03 04:00:00 14.0 11.7 23.3 0.0
76265 2019-06-03 2019-06-03 05:00:00 37.0 11.7 23.3 0.0
date ts count tmin tmax pn
78416 2019-08-31 2019-08-31 20:00:00 NaN 17.8 31.1 0.0
78417 2019-08-31 2019-08-31 21:00:00 NaN 17.8 31.1 0.0
78418 2019-08-31 2019-08-31 22:00:00 NaN 17.8 31.1 0.0
78419 2019-08-31 2019-08-31 23:00:00 NaN 17.8 31.1 0.0
78420 2019-09-01 2019-09-01 00:00:00 NaN 21.1 28.3 0.0
Fitting 3 folds for each of 1 candidates, totalling 3 fits

---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-15-06ddd6580090> in <module>
----> 1 new_pred = new_result.model.predict(df.rename(columns={'count': 'y'}))
/opt/conda/lib/python3.7/site-packages/sklearn/utils/metaestimators.py in <lambda>(*args, **kwargs)
111
112 # lambda, but not partial, allows help() to work with update_wrapper
--> 113 out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs) # noqa
114 else:
115
/opt/conda/lib/python3.7/site-packages/sklearn/pipeline.py in predict(self, X, **predict_params)
468 for _, name, transform in self._iter(with_final=False):
469 Xt = transform.transform(Xt)
--> 470 return self.steps[-1][1].predict(Xt, **predict_params)
471
472 @available_if(_final_estimator_has("fit_predict"))
/opt/conda/lib/python3.7/site-packages/greykite/sklearn/estimator/base_silverkite_estimator.py in predict(self, X, y)
365 trained_model=self.model_dict,
366 past_df=self.past_df,
--> 367 new_external_regressor_df=None) # regressors are included in X
368 pred_df = pred_res["fut_df"]
369 x_mat = pred_res["x_mat"]
/opt/conda/lib/python3.7/site-packages/greykite/algo/forecast/silverkite/forecast_silverkite.py in predict(self, fut_df, trained_model, freq, past_df, new_external_regressor_df, include_err, force_no_sim, simulation_num, fast_simulation, na_fill_func)
2266 new_external_regressor_df=None,
2267 time_features_ready=False,
-> 2268 regressors_ready=True)
2269 fut_df0 = pred_res["fut_df"]
2270 x_mat0 = pred_res["x_mat"]
/opt/conda/lib/python3.7/site-packages/greykite/algo/forecast/silverkite/forecast_silverkite.py in predict_no_sim(self, fut_df, trained_model, past_df, new_external_regressor_df, time_features_ready, regressors_ready)
1226 pred_res = predict_ml_with_uncertainty(
1227 fut_df=features_df_fut,
-> 1228 trained_model=trained_model)
1229 fut_df = pred_res["fut_df"]
1230 x_mat = pred_res["x_mat"]
/opt/conda/lib/python3.7/site-packages/greykite/algo/common/ml_models.py in predict_ml_with_uncertainty(fut_df, trained_model)
558 pred_res = predict_ml(
559 fut_df=fut_df,
--> 560 trained_model=trained_model)
561
562 y_pred = pred_res["fut_df"][y_col]
/opt/conda/lib/python3.7/site-packages/greykite/algo/common/ml_models.py in predict_ml(fut_df, trained_model)
501 y_col = trained_model["y_col"]
502 ml_model = trained_model["ml_model"]
--> 503 x_design_info = trained_model["x_design_info"]
504 min_admissible_value = trained_model["min_admissible_value"]
505 max_admissible_value = trained_model["max_admissible_value"]
KeyError: 'x_design_info'

This is a serious issue, as it makes Greykite unusable when predictions need to be made with a model restored from disk.
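In other words, the KeyError only goes away if the design info is kept in the dump and restored at load time. A minimal sketch of that round trip, reusing the forecaster from the script above (the path is a placeholder, and, as discussed later in this thread, dumping the design info can itself fail with over-long file names on some systems):

# Keep the patsy design info in the dump so that predict() can find `x_design_info` after loading.
forecaster.dump_forecast_result(
    '/tmp/forecaster_with_design_info',  # placeholder path
    object_name="object",
    dump_design_info=True,
    overwrite_exist_dir=True
)
new_forecaster = Forecaster()
new_forecaster.load_forecast_result(
    '/tmp/forecaster_with_design_info',
    load_design_info=True  # restore `x_design_info` for prediction
)
new_pred = new_forecaster.forecast_result.model.predict(df.rename(columns={'count': 'y'}))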
I have found a solution to this issue. With the help of @andreaschiappacasse and his insight, we implemented a slightly modified version of the model dump & load process that prevents the generation of pickle files with excessively long names. As a result, we can now dump models together with the design info matrix and load them back to make predictions. Please refer to the implementation below for more details.

import inspect
import os
import shutil
from collections import OrderedDict

import dill
from patsy.design_info import DesignInfo

def dump_obj(obj, dir_name, obj_name="obj", dump_design_info=True, overwrite_exist_dir=False, top_level=True):
"""See `greykite.framework.templates.pickle_utils.dump_obj`."""
# Checks if to dump design info.
if (not dump_design_info) and (isinstance(obj, DesignInfo) or (isinstance(obj, str) and obj == "x_design_info")):
return
# Checks if directory already exists.
if top_level:
dir_already_exist = os.path.exists(dir_name)
if dir_already_exist:
if not overwrite_exist_dir:
raise FileExistsError(
"The directory already exists. "
"Please either specify a new directory or "
"set overwrite_exist_dir to True to overwrite it."
)
else:
if os.path.isdir(dir_name):
# dir exists as a directory.
shutil.rmtree(dir_name)
else:
# dir exists as a file.
os.remove(dir_name)
# Creates the directory.
# Non-top-level calls may write to the same directory,
# so an existing directory is allowed in that case.
try:
os.mkdir(dir_name)
except FileExistsError:
pass
# Start dumping recursively.
try:
# Attempts to directly dump the object.
dill.dump(obj, open(os.path.join(dir_name, f"{obj_name}.pkl"), "wb"))
except NotImplementedError:
# Direct dumping fails.
# Removes the failed file.
try:
os.remove(os.path.join(dir_name, f"{obj_name}.pkl"))
except FileNotFoundError:
pass
# Attempts to do recursive dumping depending on the object type.
if isinstance(obj, OrderedDict):
# For OrderedDict (there are a lot in `patsy.design_info.DesignInfo`),
# recursively dumps the keys and values, because keys can be class instances
# and unpicklable, too.
# The keys and values have index number appended to the front,
# so the order is kept.
dill.dump("ordered_dict", open(os.path.join(dir_name, f"{obj_name}.type"), "wb")) # type "ordered_dict"
for i, (key, value) in enumerate(obj.items()):
# name = str(key)  # original version; using the key itself produced the over-long file names
name = f"{i}_{str(hash(key))}"  # the key does not need to appear in the file name
dump_obj(
key,
os.path.join(dir_name, obj_name),
f"{name}__key__",
dump_design_info=dump_design_info,
top_level=False,
)
dump_obj(
value,
os.path.join(dir_name, obj_name),
f"{name}__value__",
dump_design_info=dump_design_info,
top_level=False,
)
elif isinstance(obj, dict):
# For regular dictionary,
# recursively dumps the keys and values, because keys can be class instances
# and unpicklable, too.
# The order is not important.
dill.dump("dict", open(os.path.join(dir_name, f"{obj_name}.type"), "wb")) # type "dict"
for key, value in obj.items():
# name = str(key)  # original version; using the key itself produced the over-long file names
name = str(hash(key))  # the key does not need to appear in the file name
dump_obj(
key,
os.path.join(dir_name, obj_name),
f"{name}__key__",
dump_design_info=dump_design_info,
top_level=False,
)
dump_obj(
value,
os.path.join(dir_name, obj_name),
f"{name}__value__",
dump_design_info=dump_design_info,
top_level=False,
)
elif isinstance(obj, (list, tuple)):
# For list and tuples,
# recursively dumps the elements.
# The names have index number appended to the front,
# so the order is kept.
dill.dump(
type(obj).__name__, open(os.path.join(dir_name, f"{obj_name}.type"), "wb")
) # type "list"/"tuple"
for i, value in enumerate(obj):
dump_obj(
value,
os.path.join(dir_name, obj_name),
f"{i}_key",
dump_design_info=dump_design_info,
top_level=False,
)
elif hasattr(obj, "__class__") and not isinstance(obj, type):
# For class instance,
# recursively dumps the attributes.
dill.dump(obj.__class__, open(os.path.join(dir_name, f"{obj_name}.type"), "wb")) # type is class itself
for key, value in obj.__dict__.items():
dump_obj(
value, os.path.join(dir_name, obj_name), key, dump_design_info=dump_design_info, top_level=False
)
else:
# Other unrecognized unpicklable types, not common.
print(f"I Don't recognize type {type(obj)}")
def load_obj(dir_name, obj=None, load_design_info=True):
"""See `greykite.framework.templates.pickle_utils.load_obj`."""
# Checks if to load design info.
if (not load_design_info) and (isinstance(obj, type) and obj == DesignInfo):
return None
# Gets file names in the level.
files = os.listdir(dir_name)
if not files:
raise ValueError("dir is empty!")
# Gets the type files if any.
# Stores in a dictionary with key being the name and value being the loaded value.
obj_types = {
file.split(".")[0]: dill.load(open(os.path.join(dir_name, file), "rb")) for file in files if ".type" in file
}
# Gets directories and pickled files.
# Every type must have a directory with the same name.
directories = [file for file in files if os.path.isdir(os.path.join(dir_name, file))]
if not all([directory in obj_types for directory in directories]):
raise ValueError("type and directories do not match.")
pickles = [file for file in files if ".pkl" in file]
# Starts loading objects
if obj is None:
# obj is None indicates this is the top level directory.
# This directory can either have 1 .pkl file, or 1 .type file associated with the directory of same name.
if not obj_types:
# The only 1 .pkl file case.
if len(files) > 1:
raise ValueError("Multiple elements found in top level.")
return dill.load(open(os.path.join(dir_name, files[0]), "rb"))
else:
# The .type + dir case.
if len(obj_types) > 1:
raise ValueError("Multiple elements found in top level")
obj_name = list(obj_types.keys())[0]
obj_type = obj_types[obj_name]
return load_obj(os.path.join(dir_name, obj_name), obj_type, load_design_info=load_design_info)
else:
# If obj is not None, does recursive loading depending on the obj type.
if obj in ("list", "tuple"):
# Object is list or tuple.
# Fetches each element according to the number index to preserve orders.
result = []
# Order index is a number appended to the front.
elements = sorted(pickles + directories, key=lambda x: int(x.split("_")[0]))
# Recursively loads elements.
for element in elements:
if ".pkl" in element:
result.append(dill.load(open(os.path.join(dir_name, element), "rb")))
else:
result.append(
load_obj(
os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
)
)
if obj == "tuple":
result = tuple(result)
return result
elif obj == "dict":
# Object is a dictionary.
# Fetches keys and values recursively.
result = {}
elements = pickles + directories
keys = [element for element in elements if "__key__" in element]
values = [element for element in elements if "__value__" in element]
# Iterates through keys and finds the corresponding values.
for element in keys:
if ".pkl" in element:
key = dill.load(open(os.path.join(dir_name, element), "rb"))
else:
key = load_obj(
os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
)
# Value name could be either with .pkl or a directory.
value_name = element.replace("__key__", "__value__")
if ".pkl" in value_name:
value_name_alt = value_name.replace(".pkl", "")
else:
value_name_alt = value_name + ".pkl"
# Checks if value name is in the dir.
if (value_name not in values) and (value_name_alt not in values):
raise FileNotFoundError(f"Value not found for key {key}.")
value_name = value_name if value_name in values else value_name_alt
# Gets the value.
if ".pkl" in value_name:
value = dill.load(open(os.path.join(dir_name, value_name), "rb"))
else:
value = load_obj(
os.path.join(dir_name, value_name), obj_types[value_name], load_design_info=load_design_info
)
# Sets the key, value pair.
result[key] = value
return result
elif obj == "ordered_dict":
# Object is OrderedDict.
# Fetches keys and values according to the number index to preserve orders.
result = OrderedDict()
# Order index is a number appended to the front.
elements = sorted(pickles + directories, key=lambda x: int(x.split("_")[0]))
# elements = pickles + directories
keys = [element for element in elements if "__key__" in element]
values = [element for element in elements if "__value__" in element]
# Iterates through keys and finds the corresponding values.
for element in keys:
if ".pkl" in element:
key = dill.load(open(os.path.join(dir_name, element), "rb"))
else:
key = load_obj(
os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
)
value_name = element.replace("__key__", "__value__")
# Value name could be either with .pkl or a directory.
if ".pkl" in value_name:
value_name_alt = value_name.replace(".pkl", "")
else:
value_name_alt = value_name + ".pkl"
# Checks if value name is in the dir.
if (value_name not in values) and (value_name_alt not in values):
raise FileNotFoundError(f"Value not found for key {key}.")
value_name = value_name if value_name in values else value_name_alt
# Gets the value.
if ".pkl" in value_name:
value = dill.load(open(os.path.join(dir_name, value_name), "rb"))
else:
value = load_obj(
os.path.join(dir_name, value_name), obj_types[value_name], load_design_info=load_design_info
)
# Sets the key, value pair.
result[key] = value
return result
elif inspect.isclass(obj):
# Object is a class instance.
# Creates the class instance and sets the attributes.
# Some classes have required args during initialization;
# these args are pulled from the attributes.
init_params = list(inspect.signature(obj.__init__).parameters) # init args
elements = pickles + directories
# Gets the attribute names and their values in a dictionary.
values = {}
for element in elements:
if ".pkl" in element:
values[element.split(".")[0]] = dill.load(open(os.path.join(dir_name, element), "rb"))
else:
values[element] = load_obj(
os.path.join(dir_name, element), obj_types[element], load_design_info=load_design_info
)
# Gets the init args from values.
init_dict = {key: value for key, value in values.items() if key in init_params}
# Some attributes have a "_" at the beginning.
init_dict.update(
{key[1:]: value for key, value in values.items() if (key[1:] in init_params and key[0] == "_")}
)
# ``design_info`` does not have column_names attribute,
# which is required during init.
# The column_names param is pulled from the column_name_indexes attribute.
# This can be omitted once we allow dumping @property attributes.
if "column_names" in init_params:
init_dict["column_names"] = values["column_name_indexes"].keys()
# Creates the instance.
result = obj(**init_dict)
# Sets the attributes.
for key, value in values.items():
setattr(result, key, value)
return result
else:
# Raises an error if the object is not recognized.
# This typically does not happen when the source file is dumped
# with the `dump_obj` function.
raise ValueError(f"Object {obj} is not recognized.") |
Thank you for raising this issue and providing an alternative. My colleague @sayanpatra was working on model dumps this quarter. I'll ask him to take a look at it.
According to the documentation, the design info can be excluded from the dump (dump_design_info=False). However, trying to make predictions on new data using a model restored via

forecaster.load_forecast_result(path, load_design_info=False)

fails with KeyError: 'x_design_info' (see the traceback in the reproduction script above). This makes sense when looking at greykite.algo.common.ml_models.predict_ml, as the variable x_design_info is used by patsy to build the design matrix for the new data (see here; a rough sketch follows below).

On the other hand, dumping design_info does not only mean dealing with a bigger artifact: it may be outright impossible due to system limits on the length of the generated file names. As an example, this is what happens in my case.

Any ideas on how to work around this issue?
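For reference, here is a rough sketch of the kind of call the stored design info enables at prediction time (not the exact Greykite code; trained_model and fut_df stand in for the fitted model dictionary and the future dataframe):

from patsy import build_design_matrices

# These are the lookups shown in the traceback above; the first one is what
# fails after loading with load_design_info=False.
design_info = trained_model["x_design_info"]
ml_model = trained_model["ml_model"]

# Rebuild the design (model) matrix for the new rows using the design info
# saved at fit time, then feed it to the fitted model.
(x_mat,) = build_design_matrices([design_info], fut_df)
y_pred = ml_model.predict(x_mat)

Without x_design_info in the loaded model, this step cannot be performed, which is exactly where the KeyError is raised.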