Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Revamp CrossValidationReport to use EstimatorReport #1091

Merged
merged 107 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 102 commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
76d72c6
feat: Revamp CrossValidationReport to use EstimatorReport
glemaitre Jan 10, 2025
98d94d7
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 10, 2025
8353b00
add progress bar
glemaitre Jan 10, 2025
b695666
iter
glemaitre Jan 11, 2025
4f8aa1c
make sure to use shared memory
glemaitre Jan 11, 2025
5a42157
fix progress bar glitches
glemaitre Jan 11, 2025
e4accc8
fix attribute sorting
glemaitre Jan 11, 2025
5ace5be
iter
glemaitre Jan 11, 2025
c6333ca
add metric accessor class
glemaitre Jan 11, 2025
82f64cb
iter
glemaitre Jan 11, 2025
b47786e
add custom metric
glemaitre Jan 11, 2025
fe110ff
iter
glemaitre Jan 11, 2025
35e01ea
roc curve
glemaitre Jan 11, 2025
4d0f16f
iter
glemaitre Jan 11, 2025
d46199e
iter
glemaitre Jan 11, 2025
ae11e23
some cache optimization
glemaitre Jan 12, 2025
e3a8a59
hash computation optimization
glemaitre Jan 12, 2025
68755aa
provide a way to send pos_label
glemaitre Jan 12, 2025
3175047
iter
glemaitre Jan 12, 2025
0256166
iter
glemaitre Jan 12, 2025
8e6a532
iter
glemaitre Jan 12, 2025
2635f56
more parallelism
glemaitre Jan 12, 2025
23fd1aa
bug
glemaitre Jan 14, 2025
3ae4da8
Merge remote-tracking branch 'origin/main' into cross_validation_4
glemaitre Jan 14, 2025
7dfe4a4
convert class decorator to function decorator
glemaitre Jan 14, 2025
31a9ba3
simplify progress bar
glemaitre Jan 14, 2025
071daa5
feat: Allow for nested progress bar
glemaitre Jan 14, 2025
8414c20
actually add the file to git
glemaitre Jan 14, 2025
5238fdc
more comment
glemaitre Jan 14, 2025
3a97df2
chore: More readable version of iterating
glemaitre Jan 14, 2025
57d7eb0
document simplify
glemaitre Jan 14, 2025
1199cb3
Merge branch 'progress_bar' into cross_validation_report_3
glemaitre Jan 14, 2025
c2637b2
Merge branch 'chore_cache_predictions_2' into cross_validation_report_3
glemaitre Jan 14, 2025
dc1e119
conflict
glemaitre Jan 14, 2025
3a1cd06
fix: Fix the error message regarding immutable attribute
glemaitre Jan 14, 2025
376f77b
Merge branch 'fix_err_message_mutability' into cross_validation_report_3
glemaitre Jan 14, 2025
05af3fb
fix messages
glemaitre Jan 14, 2025
c8f87f7
feat: Expose private API to to optimize cache optimization by passing…
glemaitre Jan 14, 2025
ae292a8
tests: Check for error message with invalid strings
glemaitre Jan 14, 2025
ef8d1fe
iter
glemaitre Jan 14, 2025
6d770cc
add documentation as suggested per auguste
glemaitre Jan 14, 2025
8f9cbfe
add test to handle pos_label in scorer
glemaitre Jan 14, 2025
79ec786
iter
glemaitre Jan 14, 2025
08146ae
Merge branch 'cache_optimization' into cross_validation_report_6
glemaitre Jan 14, 2025
0915f32
chore: Differentiate __repr__ and help for report and accessors
glemaitre Jan 14, 2025
8680166
another assert for help method name
glemaitre Jan 14, 2025
cf03f94
rename clean_cache to clear_cache
glemaitre Jan 14, 2025
57f12b0
Merge branch 'is/1103' into cross_validation_report_3
glemaitre Jan 14, 2025
a6c5ccb
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 14, 2025
332a692
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 14, 2025
cce35d4
test general behaviour and attribute of the report
glemaitre Jan 15, 2025
aa18678
check the caching mechanism
glemaitre Jan 15, 2025
ad4cec8
check help and repr for metrics and plot accessors
glemaitre Jan 15, 2025
7f44249
add test metrics binary
glemaitre Jan 15, 2025
ef98506
fix typo
glemaitre Jan 15, 2025
d6b4ab3
fix typo
glemaitre Jan 15, 2025
0ea179a
add test for single metrics
glemaitre Jan 15, 2025
c9dfc8a
covert report
glemaitre Jan 15, 2025
aab38f6
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 15, 2025
1d58817
iter
glemaitre Jan 15, 2025
88f76cc
more tests
glemaitre Jan 15, 2025
a06e3ff
more tests
glemaitre Jan 15, 2025
d575dca
check scoring_names
glemaitre Jan 15, 2025
51f77a3
scorer error name
glemaitre Jan 15, 2025
772bb2a
iter
glemaitre Jan 15, 2025
c5afcd0
more tests
glemaitre Jan 15, 2025
bab0e20
iter
glemaitre Jan 15, 2025
a881c56
update pos_label documentation
glemaitre Jan 15, 2025
fe01241
update pos_label documentation
glemaitre Jan 15, 2025
467699c
api: Prepend with name of variable that we modify
glemaitre Jan 15, 2025
44055f8
Merge branch 'is/1118' into cross_validation_report_3
glemaitre Jan 15, 2025
8b4f8c4
move new convention _
glemaitre Jan 15, 2025
a69ac60
documentation reporter
glemaitre Jan 15, 2025
4e5ffb6
documentation
glemaitre Jan 16, 2025
83635d0
revert example
glemaitre Jan 16, 2025
be5e826
fix
glemaitre Jan 16, 2025
c6426ed
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 16, 2025
ee01f4f
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 16, 2025
f6b66a9
iter
glemaitre Jan 16, 2025
13efcd8
cross-validation binary default
glemaitre Jan 16, 2025
9d50c0b
cross_validation multiclass defaults
glemaitre Jan 16, 2025
26169da
iter
glemaitre Jan 16, 2025
21ee6fc
more tests
glemaitre Jan 16, 2025
0df2b42
more tests
glemaitre Jan 16, 2025
28b6fea
iter
glemaitre Jan 16, 2025
2e933cf
roc tests
glemaitre Jan 16, 2025
293a77d
prediction error plot tests
glemaitre Jan 16, 2025
0b9a083
tests
glemaitre Jan 16, 2025
29cad73
add examples
glemaitre Jan 16, 2025
c1c9e92
rename cv to cv_splitter
glemaitre Jan 17, 2025
088e4e3
Update skore/src/skore/sklearn/_cross_validation/report.py
glemaitre Jan 17, 2025
b4523de
Update skore/src/skore/sklearn/_cross_validation/report.py
glemaitre Jan 17, 2025
9fa5f17
Update skore/src/skore/sklearn/_cross_validation/report.py
glemaitre Jan 17, 2025
95478c6
Update skore/src/skore/sklearn/_estimator/report.py
glemaitre Jan 17, 2025
c8eb0a4
Update skore/src/skore/sklearn/_cross_validation/__init__.py
glemaitre Jan 17, 2025
87c0d19
chore: Harmonize `message` to `note` in notes-related methods (#1143)
augustebaum Jan 17, 2025
3f6be36
fix(UI): Item card actions are now aligned (#1145)
rouk1 Jan 17, 2025
715cafc
feat: 404 page when skore-UI as not been built (#1142)
rouk1 Jan 17, 2025
574cb65
fix(UI): Ellipsis long item name (#1147)
rouk1 Jan 17, 2025
c598b86
ci: Fix timeout by upgrading scikit-learn to latest bugfix (#1141)
glemaitre Jan 17, 2025
94d43b4
iter
glemaitre Jan 17, 2025
25bb9ba
order matter
glemaitre Jan 17, 2025
3810c8d
Merge remote-tracking branch 'origin/main' into cross_validation_repo…
glemaitre Jan 18, 2025
2c68c9e
fix: Cache the pos_label=None when calling cache_predictions
glemaitre Jan 18, 2025
fe6bfa0
Merge branch 'bug_pos_label_none_cache' into cross_validation_report_3
glemaitre Jan 18, 2025
e989de0
iter
glemaitre Jan 18, 2025
ccafc68
Merge branch 'main' into cross_validation_report_3
augustebaum Jan 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions .github/workflows/backend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ jobs:
run: |
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
python -m pip install --upgrade "pip"
python -m pip install --upgrade "build"
python -m pip install --upgrade "scikit-learn ==${{ matrix.scikit-learn }}"
# adding `.*` to the version ensures that we install the latest version of
# scikit-learn that is compatible with the specified version
python -m pip install --upgrade "scikit-learn ==${{ matrix.scikit-learn }}.*"

# Install `skore` and its dependencies
python -m pip install --upgrade ".[test]"
python -m pip install --upgrade --upgrade-strategy=eager ".[test]"

# Uninstall the `skore` package itself
python -m pip uninstall -y "skore"
Expand All @@ -127,6 +129,11 @@ jobs:
# Install `skore` without its dependencies, which are present in the venv
wheel=(dist/*.whl); python -m pip install --force-reinstall --no-deps "${wheel}"

- name: Show dependencies versions
working-directory: skore/
run: |
python -c "import skore; skore.show_versions()"

- name: Test without coverage
if: ${{ ! matrix.coverage }}
timeout-minutes: 10
Expand All @@ -139,7 +146,7 @@ jobs:
working-directory: skore/
run: |
mkdir coverage
python -m pytest src/ tests/ --junitxml=coverage/coverage.xml --cov-config=pyproject.toml --cov | tee coverage/coverage.txt
python -m pytest -n auto src/ tests/ --junitxml=coverage/coverage.xml --cov-config=pyproject.toml --cov | tee coverage/coverage.txt

- name: Upload coverage reports
if: ${{ matrix.coverage && (github.event_name == 'pull_request') }}
Expand Down
1 change: 1 addition & 0 deletions skore-ui/src/components/ProjectViewCard.vue
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ onBeforeUnmount(() => {
& .actions {
display: flex;
flex-direction: row;
align-items: center;
gap: var(--spacing-4);
opacity: 0;
transition: opacity var(--animation-duration) var(--animation-easing);
Expand Down
4 changes: 4 additions & 0 deletions skore-ui/src/components/TreeAccordionItem.vue
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,14 @@ onMounted(() => {

.label {
display: flex;
overflow: hidden;
height: var(--label-height);
flex: 1;
flex-direction: row;
align-items: center;
cursor: pointer;
transition: background-color var(--animation-duration) var(--animation-easing);
white-space: nowrap;

& .children-indicator {
color: var(--color-text-secondary);
Expand All @@ -170,8 +172,10 @@ onMounted(() => {
}

& .text {
overflow: hidden;
border-radius: var(--radius-xs);
color: var(--color-text-primary);
text-overflow: ellipsis;
}

&.has-children {
Expand Down
10 changes: 8 additions & 2 deletions skore/src/skore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@
from rich.theme import Theme

from skore.project import Project, open
from skore.sklearn import CrossValidationReporter, EstimatorReport, train_test_split
from skore.sklearn import (
CrossValidationReport,
CrossValidationReporter,
EstimatorReport,
train_test_split,
)
from skore.utils._patch import setup_jupyter_display
from skore.utils._show_versions import show_versions

__all__ = [
"CrossValidationReporter",
"CrossValidationReport",
"EstimatorReport",
"open",
"Project",
Expand All @@ -34,4 +40,4 @@
}
)

console = Console(theme=skore_console_theme, width=79)
console = Console(theme=skore_console_theme, width=88)
14 changes: 7 additions & 7 deletions skore/src/skore/item/item_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,16 @@ def keys(self) -> list[str]:
"""
return list(self.storage.keys())

def set_item_note(self, key: str, message: str, *, version=-1):
def set_item_note(self, key: str, note: str, *, version=-1):
"""Attach a note to key ``key``.

Parameters
----------
key : str
The key of the item to annotate.
May be qualified with a version number through the ``version`` argument.
message : str
The message to be attached.
note : str
The note to be attached.
version : int, default=-1
The version of the key to annotate. Default is the latest version.

Expand All @@ -172,16 +172,16 @@ def set_item_note(self, key: str, message: str, *, version=-1):
KeyError
If the ``(key, version)`` couple does not exist.
TypeError
If ``key`` or ``message`` is not a string.
If ``key`` or ``note`` is not a string.
"""
if not isinstance(key, str):
raise TypeError(f"Key should be a string; got {type(key)}")
if not isinstance(message, str):
raise TypeError(f"Message should be a string; got {type(message)}")
if not isinstance(note, str):
raise TypeError(f"Note should be a string; got {type(note)}")

try:
old = self.storage[key]
old[version]["item"]["note"] = message
old[version]["item"]["note"] = note
self.storage[key] = old
except IndexError as e:
raise KeyError((key, version)) from e
Expand Down
16 changes: 7 additions & 9 deletions skore/src/skore/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,16 @@ def list_view_keys(self) -> list[str]:
"""
return self.view_repository.keys()

def set_note(self, key: str, message: str, *, version=-1):
def set_note(self, key: str, note: str, *, version=-1):
"""Attach a note to key ``key``.

Parameters
----------
key : str
The key of the item to annotate.
May be qualified with a version number through the ``version`` argument.
message : str
The message to be attached.
note : str
The note to be attached.
version : int, default=-1
The version of the key to annotate. Default is the latest version.

Expand All @@ -290,19 +290,17 @@ def set_note(self, key: str, message: str, *, version=-1):
KeyError
If the ``(key, version)`` couple does not exist.
TypeError
If ``key`` or ``message`` is not a string.
If ``key`` or ``note`` is not a string.

Examples
--------
# Annotate latest version of key "key"
>>> project.set_note("key", "message") # doctest: +SKIP
>>> project.set_note("key", "note") # doctest: +SKIP

# Annotate first version of key "key"
>>> project.set_note("key", "message", version=0) # doctest: +SKIP
>>> project.set_note("key", "note", version=0) # doctest: +SKIP
"""
return self.item_repository.set_item_note(
key=key, message=message, version=version
)
return self.item_repository.set_item_note(key=key, note=note, version=version)

def get_note(self, key: str, *, version=-1) -> Union[str, None]:
"""Retrieve a note previously attached to key ``key``.
Expand Down
2 changes: 2 additions & 0 deletions skore/src/skore/sklearn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Enhance `sklearn` functions."""

from skore.sklearn._cross_validation import CrossValidationReport
from skore.sklearn._estimator import EstimatorReport
from skore.sklearn.cross_validation import CrossValidationReporter
from skore.sklearn.train_test_split.train_test_split import train_test_split

__all__ = [
"train_test_split",
"CrossValidationReporter",
"CrossValidationReport",
"EstimatorReport",
]
16 changes: 7 additions & 9 deletions skore/src/skore/sklearn/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,13 @@ def _rich_repr(self, class_name, help_method_name):


class _BaseReport(_HelpMixin):
"""Base class for all reports."""

def _get_help_panel_title(self):
return (
f"[bold cyan]Tools to diagnose estimator "
f"{self.estimator_name_}[/bold cyan]"
)
return ""

def _get_help_legend(self):
return (
"[cyan](↗︎)[/cyan] higher is better [orange1](↘︎)[/orange1] lower is better"
)
return ""

def _get_attributes_for_help(self):
"""Get the public attributes to display in help."""
Expand All @@ -112,7 +109,7 @@ def _get_attributes_for_help(self):

# Group X and y attributes separately
value = getattr(self, name)
if name.startswith(("X_", "y_")):
if name.startswith(("X", "y")):
if value is not None: # Only include non-None X/y attributes
xy_attributes.append(name)
else:
Expand Down Expand Up @@ -304,11 +301,12 @@ def _get_cached_response_values(
response_method : str
The response method.

pos_label : str, default=None
pos_label : int, float, bool or str, default=None
The positive label.

data_source : {"test", "train", "X_y"}, default="test"
The data source to use.

- "test" : use the test set provided when creating the reporter.
- "train" : use the train set provided when creating the reporter.
- "X_y" : use the provided `X` and `y` to compute the metric.
Expand Down
16 changes: 16 additions & 0 deletions skore/src/skore/sklearn/_cross_validation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from skore.externals._pandas_accessors import _register_accessor
from skore.sklearn._cross_validation.metrics_accessor import (
_MetricsAccessor,
_PlotMetricsAccessor,
)
from skore.sklearn._cross_validation.report import (
CrossValidationReport,
)

# add the metrics accessor to the estimator report
_register_accessor("metrics", CrossValidationReport)(_MetricsAccessor)

# add the plot accessor to the metrics accessor
_register_accessor("plot", _MetricsAccessor)(_PlotMetricsAccessor)

__all__ = ["CrossValidationReport"]
Loading
Loading