Skip to content

Commit

Permalink
Update TabRepo_Reproducibility.ipynb (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
Innixma authored Jan 28, 2025
1 parent 440c03f commit ac920ce
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/TabRepo_Reproducibility.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -655,10 +655,10 @@
"config = \"CatBoost_r1_BAG_L1\"\n",
"config_key = config.rsplit(\"_BAG_L1\", 1)[0]\n",
"config_type = config_hyperparameters[config_key][\"model_type\"]\n",
"config_hyperparameters = config_hyperparameters[config_key][\"hyperparameters\"]\n",
"conf_hyperparameters = config_hyperparameters[config_key][\"hyperparameters\"]\n",
"print(f\"Config: {config}\\n\"\n",
" f\"\\tType : {config_type}\\n\"\n",
" f\"\\tHyperparameters: {config_hyperparameters}\")\n",
" f\"\\tHyperparameters: {conf_hyperparameters}\")\n",
"\n",
"metrics = repo.metrics(datasets=[\"sensory\", \"Moneyball\"], configs=[\"CatBoost_r1_BAG_L1\", \"LightGBM_r41_BAG_L1\"])\n",
"with pd.option_context(\"display.max_rows\", None, \"display.max_columns\", None, \"display.width\", 1000):\n",
Expand All @@ -676,7 +676,7 @@
"y_val = repo.labels_val(dataset=dataset, fold=0)\n",
"print(f\"Ground Truth Val (dataset={dataset}, fold=0):\\n{y_val[:10]}\")\n",
"\n",
"df_ranks, df_ensemble_weights = repo.evaluate_ensemble(datasets=[dataset], configs=configs, ensemble_size=100)\n",
"df_ranks, df_ensemble_weights = repo.evaluate_ensembles(datasets=[dataset], ensemble_size=100)\n",
"print(f\"Ensemble rank per task:\\n{df_ranks}\")\n",
"\n",
"df_ensemble_weights_mean_sorted = df_ensemble_weights.mean(axis=0).sort_values(ascending=False)\n",
Expand Down

0 comments on commit ac920ce

Please sign in to comment.