
Commit

Update plot_cache_mechanism.py
sylvaincom committed Jan 22, 2025
1 parent 9343198 commit 3e9097d
Showing 1 changed file with 48 additions and 18 deletions.
66 changes: 48 additions & 18 deletions examples/technical_details/plot_cache_mechanism.py
@@ -18,6 +18,8 @@
os.environ["POLARS_ALLOW_FORKING_THREAD"] = "1"

# %%
# Loading some data
# =================
#
# First, we load a dataset from `skrub`. Our goal is to predict if a company paid a
# physician. The ultimate goal is to detect potential conflict of interest when it comes
@@ -35,9 +37,15 @@
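# %%
# The loading code is elided in this diff; a sketch, assuming the skrub
# ``fetch_open_payments`` loader and its ``X``/``y`` attributes:
from skrub.datasets import fetch_open_payments

dataset = fetch_open_payments()
df = dataset.X
y = dataset.y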

# %%
#
# The dataset has over 70,000 records with only categorical features.
# Some categories are not well defined.
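#
# A quick way to see this is to skim the table; a sketch, assuming
# ``TableReport`` is available in the installed version of `skrub`:
from skrub import TableReport

TableReport(df)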

# %%
# Caching with :class:`~skore.EstimatorReport` and :class:`~skore.CrossValidationReport`
# ======================================================================================
#
# We use `skrub` to create a simple predictive model that handles our dataset's
# challenges.
from skrub import tabular_learner

model = tabular_learner("classifier")
@@ -52,6 +60,11 @@
X_train, X_test, y_train, y_test = train_test_split(df, y, random_state=42)

# %%
# Caching the predictions for fast metric computation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# First, let us focus on :class:`~skore.EstimatorReport`, as the same philosophy will
# apply to :class:`~skore.CrossValidationReport`.
#
# Let's explore how :class:`~skore.EstimatorReport` uses caching to speed up
# predictions. We start by training the model:
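#
# The training cell is elided in this diff; the sketch below follows the
# public :class:`~skore.EstimatorReport` API (the ``estimator_`` attribute
# used in the scikit-learn comparison is an assumption):
import time

from sklearn.metrics import accuracy_score
from skore import EstimatorReport

# Creating the report fits the estimator on the train set and keeps the test
# set around for evaluation.
report = EstimatorReport(
    model, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)

# First computation through the report: the predictions are computed and
# cached.
start = time.time()
result = report.metrics.accuracy()
end = time.time()
print(f"Time taken: {end - start:.2f} seconds")

# The same metric computed directly with scikit-learn, for comparison.
start = time.time()
result = accuracy_score(y_test, report.estimator_.predict(X_test))
end = time.time()
print(f"Time taken: {end - start:.2f} seconds")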
@@ -90,8 +103,9 @@

# %%
#
# Both approaches take similar time.
#
# Now watch what happens when we compute accuracy again with our skore estimator report:
start = time.time()
result = report.metrics.accuracy()
end = time.time()
@@ -108,8 +122,9 @@

# %%
#
# The cache stores predictions by type and data source. This means that computing
# metrics that use the same type of predictions will be faster.
# Let's try the precision metric:
start = time.time()
result = report.metrics.precision()
end = time.time()
@@ -121,15 +136,20 @@
# %%
# We observe that it takes only a few milliseconds to compute the precision because we
# don't need to re-compute the predictions and only have to compute the precision
# metric itself.
# Since the predictions are the bottleneck in terms of computation time, we observe
# an interesting speedup.

# %%
# Caching all the possible predictions at once
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We can pre-compute all predictions at once using parallel processing:
report.cache_predictions(n_jobs=2)

# %%
#
# Now, all possible predictions are stored. Any metric calculation will be much faster,
# even on different data (like the training set):
start = time.time()
result = report.metrics.log_loss(data_source="train")
@@ -140,6 +160,8 @@
print(f"Time taken: {end - start:.2f} seconds")

# %%
# Caching external data
# ^^^^^^^^^^^^^^^^^^^^^
#
# The report can also work with external data. We use `data_source="X_y"` to
# indicate that we want to pass our own data rather than the internal train or
# test sets.
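#
# A sketch of the elided timing cell; it mirrors the call repeated below:
start = time.time()
result = report.metrics.log_loss(data_source="X_y", X=X_test, y=y_test)
end = time.time()

print(result)
print(f"Time taken: {end - start:.2f} seconds")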
@@ -153,9 +175,9 @@

# %%
#
# The first calculation in the cell above is slower than when using the internal train
# or test sets because it needs to compute a hash of the new data for later retrieval.
# Let's calculate it again:
start = time.time()
result = report.metrics.log_loss(data_source="X_y", X=X_test, y=y_test)
end = time.time()
@@ -166,8 +188,9 @@

# %%
#
# The second call is much faster because the predictions are cached!
# The remaining time corresponds to the hash computation.
# Let's compute the ROC AUC on the same data:
start = time.time()
result = report.metrics.roc_auc(data_source="X_y", X=X_test, y=y_test)
end = time.time()
@@ -178,8 +201,12 @@

# %%
# We observe that the computation is already efficient because it boils down to two
# computations: the hash of the data and the ROC-AUC metric.
# We save a lot of time because we don't need to re-compute the predictions.

# %%
# Caching for plotting
# ^^^^^^^^^^^^^^^^^^^^
#
# The cache also speeds up plots. Let's create a ROC curve:
import matplotlib.pyplot as plt
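
# The plotting cell is elided in this diff; the helper below is an assumption
# about skore's plotting accessor at the time of this commit:
start = time.time()
display = report.metrics.plot.roc()  # assumed API for drawing the ROC curve
plt.tight_layout()
end = time.time()
print(f"Time taken: {end - start:.2f} seconds")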
@@ -220,6 +247,9 @@
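# %%
# Before moving on, the cache is emptied; a sketch, assuming the report
# exposes a ``clear_cache`` method:
report.clear_cache()

# %%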
#
# This means that nothing is stored in the cache anymore.
#
# Caching with :class:`~skore.CrossValidationReport`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# :class:`~skore.CrossValidationReport` uses the same caching system for each fold
# in cross-validation by leveraging the previous :class:`~skore.EstimatorReport`:
from skore import CrossValidationReport
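
# A sketch of the elided construction; ``cv_splitter`` is assumed to control
# the number of folds:
report = CrossValidationReport(model, X=df, y=y, cv_splitter=5)

# As with the estimator report, all predictions can be cached upfront.
report.cache_predictions(n_jobs=2)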
@@ -234,8 +264,8 @@

# %%
#
# Now, all possible predictions are stored. Any metric calculation will be much faster,
# even on different data, as we showed for the :class:`~skore.EstimatorReport`.
start = time.time()
result = report.metrics.report_metrics(aggregate=["mean", "std"])
end = time.time()
