From 5190332c7ecdc274e8d4af5e5ef8a4c4f0ee6963 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Wed, 26 Jun 2024 18:34:00 +0800 Subject: [PATCH] Add some misc features. (#1816) * Normal mod * Black linting * Linting --- .../benchmarks_dynamic/DDG-DA/workflow.py | 9 +- .../baseline/rolling_benchmark.py | 9 +- qlib/contrib/meta/data_selection/dataset.py | 34 ++++++-- qlib/contrib/meta/data_selection/model.py | 12 ++- qlib/contrib/meta/data_selection/utils.py | 9 +- qlib/contrib/model/linear.py | 1 + qlib/contrib/model/pytorch_gru.py | 84 ++++++++++++------- qlib/contrib/report/data/ana.py | 16 ++++ qlib/contrib/report/data/base.py | 18 ++++ qlib/contrib/report/utils.py | 6 +- qlib/contrib/rolling/base.py | 35 ++++++-- qlib/contrib/rolling/ddgda.py | 72 ++++++++++++---- qlib/model/meta/task.py | 3 + qlib/utils/mod.py | 8 +- qlib/workflow/cli.py | 48 +++++++++-- 15 files changed, 289 insertions(+), 75 deletions(-) diff --git a/examples/benchmarks_dynamic/DDG-DA/workflow.py b/examples/benchmarks_dynamic/DDG-DA/workflow.py index 7593fe374f..8209e0e906 100644 --- a/examples/benchmarks_dynamic/DDG-DA/workflow.py +++ b/examples/benchmarks_dynamic/DDG-DA/workflow.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import os from pathlib import Path from typing import Union @@ -35,6 +36,10 @@ def __init__(self, conf_path: Union[str, Path] = DEFAULT_CONF, horizon=20, **kwa if __name__ == "__main__": - GetData().qlib_data(exists_skip=True) - auto_init() + kwargs = {} + if os.environ.get("PROVIDER_URI", "") == "": + GetData().qlib_data(exists_skip=True) + else: + kwargs["provider_uri"] = os.environ["PROVIDER_URI"] + auto_init(**kwargs) fire.Fire(DDGDABench) diff --git a/examples/benchmarks_dynamic/baseline/rolling_benchmark.py b/examples/benchmarks_dynamic/baseline/rolling_benchmark.py index 1ce30ef8a7..02b7ed4650 100644 --- a/examples/benchmarks_dynamic/baseline/rolling_benchmark.py +++ b/examples/benchmarks_dynamic/baseline/rolling_benchmark.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import os from pathlib import Path from typing import Union @@ -31,6 +32,10 @@ def __init__(self, conf_path: Union[str, Path] = DEFAULT_CONF, horizon=20, **kwa if __name__ == "__main__": - GetData().qlib_data(exists_skip=True) - auto_init() + kwargs = {} + if os.environ.get("PROVIDER_URI", "") == "": + GetData().qlib_data(exists_skip=True) + else: + kwargs["provider_uri"] = os.environ["PROVIDER_URI"] + auto_init(**kwargs) fire.Fire(RollingBenchmark) diff --git a/qlib/contrib/meta/data_selection/dataset.py b/qlib/contrib/meta/data_selection/dataset.py index 9349a12fe5..58e160f110 100644 --- a/qlib/contrib/meta/data_selection/dataset.py +++ b/qlib/contrib/meta/data_selection/dataset.py @@ -243,7 +243,7 @@ def __init__( trunc_days: int = None, rolling_ext_days: int = 0, exp_name: Union[str, InternalData], - segments: Union[Dict[Text, Tuple], float], + segments: Union[Dict[Text, Tuple], float, str], hist_step_n: int = 10, task_mode: str = MetaTask.PROC_MODE_FULL, fill_method: str = "max", @@ -271,12 +271,16 @@ def __init__( - str: the name of the experiment to store the performance of data - InternalData: a prepared internal data segments: Union[Dict[Text, Tuple], float] - the segments to divide data - both left and right + if the segment is a Dict + the segments to divide data + both left and right are included if segments is a float: the float represents the percentage of data for training + if segments is a string: + it will try its best to put its data in training and ensure that the date `segments` is in the test set hist_step_n: int length of historical steps for the meta infomation + Number of steps of the data similarity information task_mode : str Please refer to the docs of MetaTask """ @@ -383,10 +387,30 @@ def _prepare_seg(self, segment: Text) -> List[MetaTask]: if isinstance(self.segments, float): train_task_n = int(len(self.meta_task_l) * self.segments) if segment == "train": - return self.meta_task_l[:train_task_n] + train_tasks = self.meta_task_l[:train_task_n] + get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}") + return train_tasks elif segment == "test": - return self.meta_task_l[train_task_n:] + test_tasks = self.meta_task_l[train_task_n:] + get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}") + return test_tasks else: raise NotImplementedError(f"This type of input is not supported") + elif isinstance(self.segments, str): + train_tasks = [] + test_tasks = [] + for t in self.meta_task_l: + test_end = t.task["dataset"]["kwargs"]["segments"]["test"][1] + if test_end is None or pd.Timestamp(test_end) < pd.Timestamp(self.segments): + train_tasks.append(t) + else: + test_tasks.append(t) + get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}") + get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}") + if segment == "train": + return train_tasks + elif segment == "test": + return test_tasks + raise NotImplementedError(f"This type of input is not supported") else: raise NotImplementedError(f"This type of input is not supported") diff --git a/qlib/contrib/meta/data_selection/model.py b/qlib/contrib/meta/data_selection/model.py index 068f15f9d6..7aaa0cad79 100644 --- a/qlib/contrib/meta/data_selection/model.py +++ b/qlib/contrib/meta/data_selection/model.py @@ -53,7 +53,12 @@ def __init__( max_epoch=100, seed=43, alpha=0.0, + loss_skip_thresh=50, ): + """ + loss_skip_size: int + The number of threshold to skip the loss calculation for each day. + """ self.step = step self.hist_step_n = hist_step_n self.clip_method = clip_method @@ -63,6 +68,7 @@ def __init__( self.max_epoch = max_epoch self.fitted = False self.alpha = alpha + self.loss_skip_thresh = loss_skip_thresh torch.manual_seed(seed) def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False): @@ -88,12 +94,14 @@ def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False): criterion = nn.MSELoss() loss = criterion(pred, meta_input["y_test"]) elif self.criterion == "ic_loss": - criterion = ICLoss() + criterion = ICLoss(self.loss_skip_thresh) try: - loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"], skip_size=50) + loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"]) except ValueError as e: get_module_logger("MetaModelDS").warning(f"Exception `{e}` when calculating IC loss") continue + else: + raise ValueError(f"Unknown criterion: {self.criterion}") assert not np.isnan(loss.detach().item()), "NaN loss!" diff --git a/qlib/contrib/meta/data_selection/utils.py b/qlib/contrib/meta/data_selection/utils.py index 7da5028085..2fddb00963 100644 --- a/qlib/contrib/meta/data_selection/utils.py +++ b/qlib/contrib/meta/data_selection/utils.py @@ -10,7 +10,11 @@ class ICLoss(nn.Module): - def forward(self, pred, y, idx, skip_size=50): + def __init__(self, skip_size=50): + super().__init__() + self.skip_size = skip_size + + def forward(self, pred, y, idx): """forward. FIXME: - Some times it will be a slightly different from the result from `pandas.corr()` @@ -33,7 +37,7 @@ def forward(self, pred, y, idx, skip_size=50): skip_n = 0 for start_i, end_i in zip(diff_point, diff_point[1:]): pred_focus = pred[start_i:end_i] # TODO: just for fake - if pred_focus.shape[0] < skip_size: + if pred_focus.shape[0] < self.skip_size: # skip some days which have very small amount of stock. skip_n += 1 continue @@ -50,6 +54,7 @@ def forward(self, pred, y, idx, skip_size=50): ) ic_all += ic_day if len(diff_point) - 1 - skip_n <= 0: + __import__("ipdb").set_trace() raise ValueError("No enough data for calculating IC") if skip_n > 0: get_module_logger("ICLoss").info( diff --git a/qlib/contrib/model/linear.py b/qlib/contrib/model/linear.py index 7fd3d156b5..15cdb739e9 100644 --- a/qlib/contrib/model/linear.py +++ b/qlib/contrib/model/linear.py @@ -63,6 +63,7 @@ def fit(self, dataset: DatasetH, reweighter: Reweighter = None): df_train = pd.concat([df_train, df_valid]) except KeyError: get_module_logger("LinearModel").info("include_valid=True, but valid does not exist") + df_train = df_train.dropna() if df_train.empty: raise ValueError("Empty data from dataset, please check your dataset config.") if reweighter is not None: diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index 2a476a657d..e0f883f094 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -1,25 +1,25 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - from __future__ import division from __future__ import print_function +import copy +from typing import Text, Union import numpy as np import pandas as pd -from typing import Text, Union -import copy -from ...utils import get_or_create_path -from ...log import get_module_logger - import torch import torch.nn as nn import torch.optim as optim -from .pytorch_utils import count_parameters -from ...model.base import Model +from qlib.workflow import R + from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from ...log import get_module_logger +from ...model.base import Model +from ...utils import get_or_create_path +from .pytorch_utils import count_parameters class GRU(Model): @@ -212,16 +212,31 @@ def fit( evals_result=dict(), save_path=None, ): - df_train, df_valid, df_test = dataset.prepare( - ["train", "valid", "test"], - col_set=["feature", "label"], - data_key=DataHandlerLP.DK_L, - ) - if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + # prepare training and validation data + dfs = { + k: dataset.prepare( + k, + col_set=["feature", "label"], + data_key=DataHandlerLP.DK_L, + ) + for k in ["train", "valid"] + if k in dataset.segments + } + df_train, df_valid = dfs.get("train", pd.DataFrame()), dfs.get("valid", pd.DataFrame()) + + # check if training data is empty + if df_train.empty: + raise ValueError("Empty training data from dataset, please check your dataset config.") + df_train = df_train.dropna() x_train, y_train = df_train["feature"], df_train["label"] - x_valid, y_valid = df_valid["feature"], df_valid["label"] + + # check if validation data is provided + if not df_valid.empty: + df_valid = df_valid.dropna() + x_valid, y_valid = df_valid["feature"], df_valid["label"] + else: + x_valid, y_valid = None, None save_path = get_or_create_path(save_path) stop_steps = 0 @@ -235,32 +250,42 @@ def fit( self.logger.info("training...") self.fitted = True + best_param = copy.deepcopy(self.gru_model.state_dict()) for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") self.train_epoch(x_train, y_train) self.logger.info("evaluating...") train_loss, train_score = self.test_epoch(x_train, y_train) - val_loss, val_score = self.test_epoch(x_valid, y_valid) - self.logger.info("train %.6f, valid %.6f" % (train_score, val_score)) evals_result["train"].append(train_score) - evals_result["valid"].append(val_score) - if val_score > best_score: - best_score = val_score - stop_steps = 0 - best_epoch = step - best_param = copy.deepcopy(self.gru_model.state_dict()) - else: - stop_steps += 1 - if stop_steps >= self.early_stop: - self.logger.info("early stop") - break + # evaluate on validation data if provided + if x_valid is not None and y_valid is not None: + val_loss, val_score = self.test_epoch(x_valid, y_valid) + self.logger.info("train %.6f, valid %.6f" % (train_score, val_score)) + evals_result["valid"].append(val_score) + + if val_score > best_score: + best_score = val_score + stop_steps = 0 + best_epoch = step + best_param = copy.deepcopy(self.gru_model.state_dict()) + else: + stop_steps += 1 + if stop_steps >= self.early_stop: + self.logger.info("early stop") + break self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch)) self.gru_model.load_state_dict(best_param) torch.save(best_param, save_path) + # Logging + rec = R.get_recorder() + for k, v_l in evals_result.items(): + for i, v in enumerate(v_l): + rec.log_metrics(step=i, **{k: v}) + if self.use_gpu: torch.cuda.empty_cache() @@ -292,6 +317,7 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): class GRUModel(nn.Module): + def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0): super().__init__() diff --git a/qlib/contrib/report/data/ana.py b/qlib/contrib/report/data/ana.py index 567ef311d5..d01e852cee 100644 --- a/qlib/contrib/report/data/ana.py +++ b/qlib/contrib/report/data/ana.py @@ -1,5 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +""" +Here we have a comprehensive set of analysis classes. + +Here is an example. + +.. code-block:: python + + from qlib.contrib.report.data.ana import FeaMeanStd + fa = FeaMeanStd(ret_df) + fa.plot_all(wspace=0.3, sub_figsize=(12, 3), col_n=5) + +""" import pandas as pd import numpy as np from qlib.contrib.report.data.base import FeaAnalyser @@ -152,6 +164,7 @@ def plot_single(self, col, ax): self._kurt[col].plot(ax=right_ax, label="kurt", color="green") right_ax.set_xlabel("") right_ax.set_ylabel("kurt") + right_ax.grid(None) # set the grid to None to avoid two layer of grid h1, l1 = ax.get_legend_handles_labels() h2, l2 = right_ax.get_legend_handles_labels() @@ -171,12 +184,15 @@ def plot_single(self, col, ax): ax.set_xlabel("") ax.set_ylabel("mean") ax.legend() + ax.tick_params(axis="x", rotation=90) right_ax = ax.twinx() self._std[col].plot(ax=right_ax, label="std", color="green") right_ax.set_xlabel("") right_ax.set_ylabel("std") + right_ax.tick_params(axis="x", rotation=90) + right_ax.grid(None) # set the grid to None to avoid two layer of grid h1, l1 = ax.get_legend_handles_labels() h2, l2 = right_ax.get_legend_handles_labels() diff --git a/qlib/contrib/report/data/base.py b/qlib/contrib/report/data/base.py index a91eda48e6..0861233b6d 100644 --- a/qlib/contrib/report/data/base.py +++ b/qlib/contrib/report/data/base.py @@ -14,6 +14,24 @@ class FeaAnalyser: def __init__(self, dataset: pd.DataFrame): + """ + + Parameters + ---------- + dataset : pd.DataFrame + + We often have multiple columns for dataset. Each column corresponds to one sub figure. + There will be a datatime column in the index levels. + Aggretation will be used for more summarized metrics overtime. + Here is an example of data: + + .. code-block:: + + return + datetime instrument + 2007-02-06 equity_tpx 0.010087 + equity_spx 0.000786 + """ self._dataset = dataset with TimeInspector.logt("calc_stat_values"): self.calc_stat_values() diff --git a/qlib/contrib/report/utils.py b/qlib/contrib/report/utils.py index 70de85198a..8d3d3fac9a 100644 --- a/qlib/contrib/report/utils.py +++ b/qlib/contrib/report/utils.py @@ -4,7 +4,7 @@ import pandas as pd -def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False): +def sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False): """sub_fig_generator. it will return a generator, each row contains sub graph @@ -13,7 +13,7 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None Parameters ---------- - sub_fs : + sub_figsize : the figure size of each subgraph in * subgraphs col_n : the number of subgraph in each row; It will generating a new graph after generating of subgraphs. @@ -33,7 +33,7 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None while True: fig, axes = plt.subplots( - row_n, col_n, figsize=(sub_fs[0] * col_n, sub_fs[1] * row_n), sharex=sharex, sharey=sharey + row_n, col_n, figsize=(sub_figsize[0] * col_n, sub_figsize[1] * row_n), sharex=sharex, sharey=sharey ) plt.subplots_adjust(wspace=wspace, hspace=hspace) axes = axes.reshape(row_n, col_n) diff --git a/qlib/contrib/rolling/base.py b/qlib/contrib/rolling/base.py index d179efb38b..05467a6be2 100644 --- a/qlib/contrib/rolling/base.py +++ b/qlib/contrib/rolling/base.py @@ -73,8 +73,8 @@ def __init__( The horizon of the prediction target. This is used to override the prediction horizon of the file. h_path : Optional[str] - the dumped data handler; - It may come from other data source. It will override the data handler in the config. + It is other data source that is dumped as a handler. It will override the data handler section in the config. + If it is not given, it will create a customized cache for the handler when `enable_handler_cache=True` test_end : Optional[str] the test end for the data. It is typically used together with the handler You can do the same thing with task_ext_conf in a more complicated way @@ -119,7 +119,7 @@ def _raw_conf(self) -> dict: with self.conf_path.open("r") as f: return yaml.safe_load(f) - def _replace_hanler_with_cache(self, task: dict): + def _replace_handler_with_cache(self, task: dict): """ Due to the data processing part in original rolling is slow. So we have to This class tries to add more feature @@ -159,13 +159,20 @@ def basic_task(self, enable_handler_cache: Optional[bool] = True): # - get horizon automatically from the expression!!!! raise NotImplementedError(f"This type of input is not supported") else: - self.logger.info("The prediction horizon is overrided") - task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [ - "Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1) - ] + if enable_handler_cache and self.h_path is not None: + self.logger.info("Fail to override the horizon due to data handler cache") + else: + self.logger.info("The prediction horizon is overrided") + if isinstance(task["dataset"]["kwargs"]["handler"], dict): + task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [ + "Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1) + ] + else: + self.logger.warning("Try to automatically configure the lablel but failed.") - if enable_handler_cache: - task = self._replace_hanler_with_cache(task) + if self.h_path is not None or enable_handler_cache: + # if we already have provided data source or we want to create one + task = self._replace_handler_with_cache(task) task = self._update_start_end_time(task) if self.task_ext_conf is not None: @@ -173,6 +180,16 @@ def basic_task(self, enable_handler_cache: Optional[bool] = True): self.logger.info(task) return task + def run_basic_task(self): + """ + Run the basic task without rolling. + This is for fast testing for model tunning. + """ + task = self.basic_task() + print(task) + trainer = TrainerR(experiment_name=self.exp_name) + trainer([task]) + def get_task_list(self) -> List[dict]: """return a batch of tasks for rolling.""" task = self.basic_task() diff --git a/qlib/contrib/rolling/ddgda.py b/qlib/contrib/rolling/ddgda.py index 25fb4c36e2..b62820ccea 100644 --- a/qlib/contrib/rolling/ddgda.py +++ b/qlib/contrib/rolling/ddgda.py @@ -80,6 +80,11 @@ def __init__( sim_task_model: UTIL_MODEL_TYPE = "gbdt", meta_1st_train_end: Optional[str] = None, alpha: float = 0.01, + loss_skip_thresh: int = 50, + fea_imp_n: Optional[int] = 30, + meta_data_proc: Optional[str] = "V01", + segments: Union[float, str] = 0.62, + hist_step_n: int = 30, working_dir: Optional[Union[str, Path]] = None, **kwargs, ): @@ -94,6 +99,15 @@ def __init__( alpha: float Setting the L2 regularization for ridge The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently..) + loss_skip_thresh: int + The thresh to skip the loss calculation for each day. If the number of item is less than it, it will skip the loss on that day. + meta_data_proc : Optional[str] + How we process the meta dataset for learning meta model. + segments : Union[float, str] + if segments is a float: + The ratio of training data in the meta task dataset + if segments is a string: + it will try its best to put its data in training and ensure that the date `segments` is in the test set """ # NOTE: # the horizon must match the meaning in the base task template @@ -104,14 +118,22 @@ def __init__( super().__init__(**kwargs) self.working_dir = self.conf_path.parent if working_dir is None else Path(working_dir) self.proxy_hd = self.working_dir / "handler_proxy.pkl" + self.fea_imp_n = fea_imp_n + self.meta_data_proc = meta_data_proc + self.loss_skip_thresh = loss_skip_thresh + self.segments = segments + self.hist_step_n = hist_step_n def _adjust_task(self, task: dict, astype: UTIL_MODEL_TYPE): """ - some task are use for special purpose. + Base on the original task, we need to do some extra things. + For example: - GBDT for calculating feature importance - Linear or GBDT for calculating similarity - Datset (well processed) that aligned to Linear that for meta learning + + So we may need to change the dataset and model for the special purpose and other settings remains the same. """ # NOTE: here is just for aligning with previous implementation # It is not necessary for the current implementation @@ -119,12 +141,16 @@ def _adjust_task(self, task: dict, astype: UTIL_MODEL_TYPE): if astype == "gbdt": task["model"] = LGBM_MODEL if isinstance(handler, dict): + # We don't need preprocessing when using GBDT model for k in ["infer_processors", "learn_processors"]: if k in handler.setdefault("kwargs", {}): handler["kwargs"].pop(k) elif astype == "linear": task["model"] = LINEAR_MODEL - handler["kwargs"].update(PROC_ARGS) + if isinstance(handler, dict): + handler["kwargs"].update(PROC_ARGS) + else: + self.logger.warning("The handler can't be adjusted.") else: raise ValueError(f"astype not supported: {astype}") return task @@ -155,12 +181,15 @@ def _dump_data_for_proxy_model(self): The meta model will be trained upon the proxy forecasting model. This dataset is for the proxy forecasting model. """ - topk = 30 - fi = self._get_feature_importance() - col_selected = fi.nlargest(topk) + # NOTE: adjusting to `self.sim_task_model` just for aligning with previous implementation. + # In previous version. The data for proxy model is using sim_task_model's way for processing task = self._adjust_task(self.basic_task(enable_handler_cache=False), self.sim_task_model) task = replace_task_handler_with_cache(task, self.working_dir) + # if self.meta_data_proc is not None: + # else: + # # Otherwise, we don't need futher processing + # task = self.basic_task() dataset = init_instance_by_config(task["dataset"]) prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) @@ -168,12 +197,18 @@ def _dump_data_for_proxy_model(self): feature_df = prep_ds["feature"] label_df = prep_ds["label"] - feature_selected = feature_df.loc[:, col_selected.index] + if self.fea_imp_n is not None: + fi = self._get_feature_importance() + col_selected = fi.nlargest(self.fea_imp_n) + feature_selected = feature_df.loc[:, col_selected.index] + else: + feature_selected = feature_df - feature_selected = feature_selected.groupby("datetime", group_keys=False).apply( - lambda df: (df - df.mean()).div(df.std()) - ) - feature_selected = feature_selected.fillna(0.0) + if self.meta_data_proc == "V01": + feature_selected = feature_selected.groupby("datetime", group_keys=False).apply( + lambda df: (df - df.mean()).div(df.std()) + ) + feature_selected = feature_selected.fillna(0.0) df_all = { "label": label_df.reindex(feature_selected.index), @@ -223,7 +258,10 @@ def _train_meta_model(self, fill_method="max"): # 1) leverage the simplified proxy forecasting model to train meta model. # - Only the dataset part is important, in current version of meta model will integrate the - # the train_start for training meta model does not necessarily align with final rolling + # NOTE: + # - The train_start for training meta model does not necessarily align with final rolling + # But please select a right time to make sure the finnal rolling tasks are not leaked in the training data. + # - The test_start is automatically aligned to the next day of test_end. Validation is ignored. train_start = "2008-01-01" if self.train_start is None else self.train_start train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") @@ -249,9 +287,9 @@ def _train_meta_model(self, fill_method="max"): kwargs = dict( task_tpl=proxy_forecast_model_task, step=self.step, - segments=0.62, # keep test period consistent with the dataset yaml + segments=self.segments, # keep test period consistent with the dataset yaml trunc_days=1 + self.horizon, - hist_step_n=30, + hist_step_n=self.hist_step_n, fill_method=fill_method, rolling_ext_days=0, ) @@ -268,7 +306,13 @@ def _train_meta_model(self, fill_method="max"): with R.start(experiment_name=self.meta_exp_name): R.log_params(**kwargs) mm = MetaModelDS( - step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=30, seed=43, alpha=self.alpha + step=self.step, + hist_step_n=kwargs["hist_step_n"], + lr=0.001, + max_epoch=30, + seed=43, + alpha=self.alpha, + loss_skip_thresh=self.loss_skip_thresh, ) mm.fit(md) R.save_objects(model=mm) diff --git a/qlib/model/meta/task.py b/qlib/model/meta/task.py index 3204910010..a051acf146 100644 --- a/qlib/model/meta/task.py +++ b/qlib/model/meta/task.py @@ -51,3 +51,6 @@ def get_meta_input(self) -> object: Return the **processed** meta_info """ return self.meta_info + + def __repr__(self): + return f"MetaTask(task={self.task}, meta_info={self.meta_info})" diff --git a/qlib/utils/mod.py b/qlib/utils/mod.py index e539572606..4e0cb707f3 100644 --- a/qlib/utils/mod.py +++ b/qlib/utils/mod.py @@ -161,7 +161,13 @@ def init_instance_by_config( # path like 'file:////obj.pkl' pr = urlparse(config) if pr.scheme == "file": - pr_path = os.path.join(pr.netloc, pr.path) if bool(pr.path) else pr.netloc + + # To enable relative path like file://data/a/b/c.pkl. pr.netloc will be data + path = pr.path + if pr.netloc != "": + path = path.lstrip("/") + + pr_path = os.path.join(pr.netloc, path) if bool(pr.path) else pr.netloc with open(os.path.normpath(pr_path), "rb") as f: return pickle.load(f) else: diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py index c2265ea5db..cda3fdbe16 100644 --- a/qlib/workflow/cli.py +++ b/qlib/workflow/cli.py @@ -1,18 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import logging -import sys import os from pathlib import Path +import sys -import qlib import fire +from jinja2 import Template, meta import ruamel.yaml as yaml + +import qlib from qlib.config import C -from qlib.model.trainer import task_train -from qlib.utils.data import update_config from qlib.log import get_module_logger +from qlib.model.trainer import task_train from qlib.utils import set_log_with_config +from qlib.utils.data import update_config set_log_with_config(C.logging_config) logger = get_module_logger("qrun", logging.INFO) @@ -47,6 +49,39 @@ def sys_config(config, config_path): sys.path.append(str(Path(config_path).parent.resolve().absolute() / p)) +def render_template(config_path: str) -> str: + """ + render the template based on the environment + + Parameters + ---------- + config_path : str + configuration path + + Returns + ------- + str + the rendered content + """ + with open(config_path, "r") as f: + config = f.read() + # Set up the Jinja2 environment + template = Template(config) + + # Parse the template to find undeclared variables + env = template.environment + parsed_content = env.parse(config) + variables = meta.find_undeclared_variables(parsed_content) + + # Get context from os.environ according to the variables + context = {var: os.getenv(var, "") for var in variables if var in os.environ} + logger.info(f"Render the template with the context: {context}") + + # Render the template with the context + rendered_content = template.render(context) + return rendered_content + + # workflow handler function def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"): """ @@ -67,8 +102,9 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"): market: csi300 """ - with open(config_path) as fp: - config = yaml.safe_load(fp) + # Render the template + rendered_yaml = render_template(config_path) + config = yaml.safe_load(rendered_yaml) base_config_path = config.get("BASE_CONFIG_PATH", None) if base_config_path: