
Major Refactor: Add save/load to dir, code refactor, etc. #86

Merged · 117 commits into main · Jan 14, 2025

Conversation

@Innixma (Collaborator) commented on Jan 11, 2025

Issue #, if available:

Description of changes:

This PR contains a major refactor to streamline a lot of logic that was previously hacky or exclusive to the scripts code and therefore hard to use by casual users.

Note that I have verified that the results of evaluate_baselines.py are identical to mainline, so these changes do not impact the results of our simulations.

  • Added evaluate_ensembles / evaluate_ensemble to replace the previous evaluate_ensemble, adding more flexibility such as the option to specify a time_limit for the ensemble, which was previously only possible through a hack in the scripts code.
  • Added save/load logic for Repo via from_dir and to_dir to avoid relying on pickle files. This dramatically improves portability and makes it far easier for others to share their repo artifacts, which was previously very involved (see the sketch after this list).
  • Added save/load logic for SimulationContext via from_dir and to_dir so that we don't rely on pickle files.
  • Added save/load logic for Context via from_json and to_json to avoid relying on pickle files.
  • Generally improved the consistency and ease of formatting the input files for a repo/context.
  • Lots of cleanup of baselines.py to use the enhanced repo methods rather than hard-coding the important logic into the script code.
  • Added repo.from_raw as a greatly simplified way to initialize a new repo with benchmark results. Refer to run_quickstart_from_scratch for details on how this simplifies the process.
  • Added type hints in many places, along with improved docstrings.
  • Added repo.compare_metrics and repo.plot_overall_rank_comparison. These are experimental methods with TODOs; they are part of logic that will eventually make comparison and evaluation easier. I haven't integrated them into the rest of the scripts yet, as I wanted to avoid edits to evaluate_baselines.py in this PR.
  • Switched time_utils to use dataset instead of tid. Now all scripts/functions are consistently using dataset.
  • Added support to return validation error as an additional output when simulating ensembles.
  • Added support to optimize on the test error instead of the val error (cheater mode) during ensembling for debugging purposes such as measuring the generalization gap.
  • Added unit tests to verify many advanced features for repository equivalence checks.
  • Added unit tests verifying that repositories saved and loaded with to_dir and from_dir are identical, even when they are saved and loaded from a new directory.
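
A minimal sketch of the new save/load round trip described above (load_repository and the repo name come from the README example later in this thread; the EvaluationRepository import path and the directory path are assumptions for illustration):

```python
from tabrepo import load_repository
# Assumption: the repo class is importable as EvaluationRepository;
# the PR only states that Repo gains to_dir / from_dir.
from tabrepo import EvaluationRepository

repo = load_repository("D244_F3_C1530_30")
repo.to_dir("./my_repo_dir")  # save to a directory, no pickle files ("./my_repo_dir" is a placeholder)
repo_2 = EvaluationRepository.from_dir("./my_repo_dir")  # loads a repo identical to `repo`
```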

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@Innixma requested a review from geoalgo on Jan 11, 2025 01:33
@Innixma changed the title from "[WIP] Major Refactor" to "Major Refactor: Add save/load to dir, code refactor, etc." on Jan 11, 2025
@Innixma marked this pull request as ready for review on Jan 11, 2025 01:34
@Innixma added this to the TabRepo 2.0 milestone on Jan 11, 2025
@geoalgo (Collaborator) left a comment:

Left some comments, looks mostly good to me. It is a bit hard to review so much code at once; I hope my comments can still be useful :-) The main thing that struck me as not great from a user POV is the dataframe with a single row, and the cases where a dict is returned instead of a dataclass. Thanks a lot for the continuous improvements on tabrepo!

```diff
@@ -92,7 +92,7 @@ To evaluate an ensemble of any list of configuration, you can run the following:
 from tabrepo import load_repository
 repo = load_repository("D244_F3_C1530_30")
-print(repo.evaluate_ensemble(datasets=["Australian"], configs=["CatBoost_r22_BAG_L1", "RandomForest_r12_BAG_L1"]))
+print(repo.evaluate_ensemble(dataset="Australian", fold=0, configs=["CatBoost_r22_BAG_L1", "RandomForest_r12_BAG_L1"]))
```
@geoalgo (Collaborator): makes sense

```diff
 import copy
 import itertools
-from typing import List, Optional, Tuple
+from typing import List
```
@geoalgo (Collaborator):

Really not for this PR, but as an FYI, we can use built-in generics like list[float] for annotations and drop the import in recent Python versions.

edit: I see you are aware of this, since there are edits later that remove Dict
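
A minimal sketch of the built-in generics geoalgo mentions (PEP 585, Python 3.9+; the function is a made-up illustration, not code from this PR):

```python
# No `from typing import List, Dict` needed on Python 3.9+:
def mean_error(errors: list[float]) -> dict[str, float]:
    return {"mean_error": sum(errors) / len(errors)}
```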

@Innixma (Collaborator, Author):

Yeah, I am planning to remove all of the instances, but it would have added a lot of changes that aren't functional changes, so I tried to avoid it in this PR. Can do in a follow-up, though.

```diff
@@ -34,110 +28,99 @@ class ResultRow:
     normalized_error: float
     time_train_s: float
     time_infer_s: float
+    metric_error_val: float = None
```
@geoalgo (Collaborator) suggested change:

```diff
-metric_error_val: float = None
+metric_error_val: float | None = None
```

@Innixma (Collaborator, Author):

As a general practice, I'm curious what you think about the | None type hint practice.

To me it feels redundant, because any time = None is present, it by definition means that | None is part of the type, so foo: float | None = None provides the same information as foo: float = None. Also, I'm pretty sure PyCharm treats them the same way in terms of how they interact with the IDE.

Of course, it doesn't "hurt" to have | None, but it feels like extra clutter for the sake of clutter (and it would probably touch over 1000 LoC, since a None default is extremely common).

I guess PEP 484 seems to suggest being explicit: https://peps.python.org/pep-0484/#union-types

Fine to do either way, but I'll probably not make the edits in this PR. This kind of change is best done in bulk, as the only contribution in a PR.

@geoalgo (Collaborator):

> Fine to do either way, but I'll probably not make the edits in this PR. This kind of change is best done in bulk, as the only contribution in a PR.

Sure, works for me. I don't have a strong opinion, but following the PEP as much as possible is a good idea.
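
For reference, a minimal sketch of the explicit style the PEP suggests (the field mirrors ResultRow from the diff above; the class itself is illustrative, Python 3.10+):

```python
from dataclasses import dataclass

@dataclass
class ResultRowSketch:  # illustrative stand-in for ResultRow
    time_train_s: float
    # Explicit per PEP 484: spell out `| None` in the annotation instead of
    # relying on the implicit Optional implied by the None default.
    metric_error_val: float | None = None
```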

```python
config_selected: list = None
seed: int = None
metadata: dict = None
```
@geoalgo (Collaborator) suggested change:

```diff
-metadata: dict = None
+metadata: dict | None = None
```

@Innixma (Collaborator, Author): ditto

scripts/baseline_comparison/baselines.py (outdated comment thread; resolved)
Comment on lines 116 to 157:

```python
results = scorer.compute_errors(configs=configs)
metric_error = results[task]["metric_error"]
ensemble_weights = results[task]["ensemble_weights"]
metric_error_val = results[task]["metric_error_val"]

dataset_info = self.dataset_info(dataset=dataset)
metric = dataset_info["metric"]
problem_type = dataset_info["problem_type"]

# select configurations used in the ensemble as infer time only depends on the models with non-zero weight.
fail_if_missing = self._config_fallback is None
config_selected_ensemble = [
    config for i, config in enumerate(configs) if ensemble_weights[i] != 0
]

runtimes = get_runtime(
    repo=self,
    dataset=dataset,
    fold=fold,
    config_names=configs,
    runtime_col='time_train_s',
    fail_if_missing=fail_if_missing,
)
latencies = get_runtime(
    repo=self,
    dataset=dataset,
    fold=fold,
    config_names=config_selected_ensemble,
    runtime_col='time_infer_s',
    fail_if_missing=fail_if_missing,
)
time_train_s = sum(runtimes.values())
time_infer_s = sum(latencies.values())

output_dict = {
    "metric_error": [metric_error],
    "metric": [metric],
    "time_train_s": [time_train_s],
    "time_infer_s": [time_infer_s],
    "problem_type": [problem_type],
    "metric_error_val": [metric_error_val],
}
```
@geoalgo (Collaborator):

If we put this code in the main class, I think it would be good to put high-level comments on the key blocks to make the code more readable (could be a TODO).

@Innixma (Collaborator, Author): added some inline comments

```diff
 tid, fold = task_to_tid_fold(task=task)
 dataset = self.tid_to_dataset_name_dict[tid]
 return self.ensemble_scorer.evaluate_task(dataset=dataset, fold=fold, models=models)

-def compute_errors(self, configs: List[str]) -> Tuple[Dict[str, float], Dict[str, np.array]]:
+def compute_errors(self, configs: list[str]) -> dict[str, dict[str, ...]]:
```
@geoalgo (Collaborator):

The ellipsis is probably unintended? If you mean to indicate a type too complex, this would be better, as it would be the right type:

```diff
-def compute_errors(self, configs: list[str]) -> dict[str, dict[str, ...]]:
+def compute_errors(self, configs: list[str]) -> dict[str, dict[str, object]]:
```

For cases with complex outputs, consider using a dataclass, as it is generally recommended over dict (it has a lot of advantages).

@Innixma (Collaborator, Author):

Agreed, object makes more sense. Regarding dataclass, I generally agree, but it would be a non-trivial lift to refactor. Will consider this later on for TabRepo 2.0.
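
For illustration, a minimal sketch of the dataclass alternative geoalgo describes (the field names mirror the keys accessed in compute_errors above; the class name TaskErrors is hypothetical):

```python
from dataclasses import dataclass

import numpy as np

# Hypothetical replacement for the inner dict returned per task; the fields
# mirror the keys used in this PR ("metric_error", "metric_error_val", ...).
@dataclass
class TaskErrors:
    metric_error: float
    metric_error_val: float
    ensemble_weights: np.ndarray

# The return type would then be dict[str, TaskErrors] instead of
# dict[str, dict[str, object]], giving attribute access and type checking.
```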

@Innixma (Collaborator, Author):

For some reason git isn't allowing me to commit the suggestion directly, so I sent a commit separately.

```python
method = framework_type if framework_type else "All"
if prefix is None:
```
@geoalgo (Collaborator):

Do the scripts to analyse the results from the paper still work with those changes? You mention the results are the same float-wise; I wonder if the names are also compatible. (If not, we should mention this in the README.)

@Innixma (Collaborator, Author):

Yes, they do work and are unchanged. prefix and all are always None and False for evaluate_baselines.py. I have other code that actually sets these to non-default values from when I was testing TabPFNMix, but that can be a follow-up PR, as it would change evaluate_baselines.py.

@Innixma (Collaborator, Author): The names are identical.

```diff
@@ -27,11 +27,15 @@
 class Experiment:
     expname: str  # name of the parent experiment used to store the file
     name: str  # name of the specific experiment, e.g. "localsearch"
-    run_fun: Callable[[], List[ResultRow]]  # function to execute to obtain results
+    run_fun: Callable[..., List[ResultRow]]  # function to execute to obtain results
```
@geoalgo (Collaborator):

Any or object is probably better than the ellipsis, as ellipsis is a type that won't match here, right?

@Innixma (Collaborator, Author):

I remember checking this, and ... was preferred. Here is what Google's AI response was:

> When using the Callable type hint in Python, you can use either Callable[...] or Callable[object] to represent a callable that takes any number of arguments and returns any type. Here's the difference:
> Callable[...]: This is the more concise way to represent a callable with any signature. It means that the callable can take any number of arguments of any type and return any type.
> Callable[object]: This is the more explicit way to represent a callable with any signature. It means that the callable can take any number of arguments of any type and return a value of type object.
> In most cases, Callable[...] is preferred due to its brevity.
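
Concretely, in Callable the ellipsis stands in for the whole parameter list and means "any signature"; a minimal sketch (the names below are illustrative, not from this PR):

```python
from typing import Callable, List

# Ellipsis as the parameter list: accepts any arguments.
run_any: Callable[..., List[str]]
# Explicit empty parameter list: must take zero arguments.
run_none: Callable[[], List[str]]
```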


```diff
-def data(self, ignore_cache: bool = False):
+def data(self, ignore_cache: bool = False) -> pd.DataFrame:
```
@geoalgo (Collaborator): Oh, I had forgotten this code 🙈

@Innixma (Collaborator, Author):

The old version had some pretty crazy weirdness going on due to what I think was late binding in the lambdas: when a lambda refers to variables that are edited after the lambda's creation, it uses the new values rather than the old intended values. By instead passing the kwargs at the time of calling .data(), this problem is avoided. See the sketch below.
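
A minimal illustration of that pitfall (generic Python, not code from this PR):

```python
# Each lambda closes over the *variable* i, not its value at creation time:
funcs = [lambda: i for i in range(3)]
print([f() for f in funcs])  # [2, 2, 2] -- all see the final value of i

# Binding the current value as a default argument captures it at creation:
funcs = [lambda i=i: i for i in range(3)]
print([f() for f in funcs])  # [0, 1, 2]
```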

```diff
@@ -17,5 +17,5 @@ def f():
     return pd.DataFrame({"a": [1, 2], "b": [3, 4]})

 for ignore_cache in [True, False]:
-    res = cache_function_dataframe(f, "f", ignore_cache=ignore_cache)
+    res = cache_function_dataframe(f, "f", cache_path="tmp_cache_dir", ignore_cache=ignore_cache)
```
@geoalgo (Collaborator) commented Jan 14, 2025:

Where is tmp_cache_dir defined?

Oh, I see now. It would probably make sense to use a true tempdir; it makes a lot of sense to use one to avoid side effects.

@Innixma (Collaborator, Author):

I don't remember off the top of my head how to pass a tempdir as an input to a function call; feel free to send a PR if you happen to know. For now I'm adding a TODO.
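
For what it's worth, a minimal sketch of geoalgo's suggestion using the standard library (cache_function_dataframe and f are taken from the diff above; the test wrapper and the final check are hypothetical):

```python
import tempfile

import pandas as pd

def f():
    return pd.DataFrame({"a": [1, 2], "b": [3, 4]})

# Hypothetical rewrite of the test above: TemporaryDirectory creates a real
# temp dir and deletes it (and the cache files) when the block exits.
def test_cache_function_dataframe():
    for ignore_cache in [True, False]:
        with tempfile.TemporaryDirectory() as tmp_dir:
            res = cache_function_dataframe(f, "f", cache_path=tmp_dir, ignore_cache=ignore_cache)
            assert res.equals(f())  # hypothetical check: cached result matches
```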

```python
# TODO: Add fillna
# TODO: Docstring
# Q: Whether to keep these functions a part of TabRepo or keep them separate as a part of new fit()-package
def compare_metrics(
```
@geoalgo (Collaborator):

Oh, I forgot to make this point in my review. I would highly recommend moving compare_metrics and plot_overall_rank_comparison into utils and out of the repository base class, as those methods are highly complex (they double the LOC of the class) and are specific to one use-case.

@Innixma (Collaborator, Author):

Agreed, although I'd prefer to do this in a follow-up PR. I think I'll keep it as is in this PR and make a dedicated PR to move this logic so it is easier to review.

These two methods are WIP, so I'm planning for them to change quite a bit before 2.0.

@Innixma merged commit 329deab into main on Jan 14, 2025
1 check passed