Skip to content

Commit

Permalink
Merge branch 'main' into cross_validation_report_3
Browse files Browse the repository at this point in the history
  • Loading branch information
augustebaum authored Jan 20, 2025
2 parents e989de0 + 706d620 commit ccafc68
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 3 deletions.
8 changes: 7 additions & 1 deletion skore/src/skore/sklearn/_plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def help(self):

console.print(self._create_help_panel())

def __repr__(self):
def __str__(self):
"""Return a string representation using rich."""
console = Console(file=StringIO(), force_terminal=False)
console.print(
Expand All @@ -93,6 +93,12 @@ def __repr__(self):
)
return console.file.getvalue()

def __repr__(self):
"""Return a string representation using rich."""
console = Console(file=StringIO(), force_terminal=False)
console.print(f"[cyan]skore.{self.__class__.__name__}(...)[/cyan]")
return console.file.getvalue()


class _ClassifierCurveDisplayMixin:
"""Mixin class to be used in Displays requiring a binary classifier.
Expand Down
28 changes: 26 additions & 2 deletions skore/tests/unit/sklearn/plot/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,31 @@ def test_display_help(pyplot, capsys, plot_func, estimator, dataset):
assert f"{display.__class__.__name__}" in captured.out


@pytest.mark.parametrize(
"plot_func, estimator, dataset",
[
("roc", LogisticRegression(), make_classification(random_state=42)),
(
"precision_recall",
LogisticRegression(),
make_classification(random_state=42),
),
("prediction_error", LinearRegression(), make_regression(random_state=42)),
],
)
def test_display_str(pyplot, plot_func, estimator, dataset):
"""Check that __str__ returns a string starting with the expected prefix."""
X_train, X_test, y_train, y_test = train_test_split(*dataset, random_state=42)
report = EstimatorReport(
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
display = getattr(report.metrics.plot, plot_func)()

str_str = str(display)
assert f"{display.__class__.__name__}" in str_str
assert "display.help()" in str_str


@pytest.mark.parametrize(
"plot_func, estimator, dataset",
[
Expand All @@ -52,8 +77,7 @@ def test_display_repr(pyplot, plot_func, estimator, dataset):
display = getattr(report.metrics.plot, plot_func)()

repr_str = repr(display)
assert f"{display.__class__.__name__}" in repr_str
assert "display.help()" in repr_str
assert f"skore.{display.__class__.__name__}(...)" in repr_str


@pytest.mark.parametrize(
Expand Down
25 changes: 25 additions & 0 deletions sphinx/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,16 @@
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

import sys
import os
from sphinx_gallery.sorting import ExplicitOrder

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath("sphinxext"))
from github_link import make_linkcode_resolve # noqa

project = "skore"
copyright = "2024, Probabl"
author = "Probabl"
Expand All @@ -23,6 +30,7 @@
"sphinx.ext.autosummary",
"sphinx.ext.githubpages",
"sphinx.ext.intersphinx",
"sphinx.ext.linkcode",
"sphinx_design",
"sphinx_gallery.gen_gallery",
"sphinx_copybutton",
Expand Down Expand Up @@ -172,3 +180,20 @@ def reset_mpl(gallery_conf, fname):
# Sphinx-Copybutton configuration
copybutton_prompt_text = r">>> |\.\.\. |\$ "
copybutton_prompt_is_regexp = True

# -- Options for github link for what's new -----------------------------------

# Config for sphinx_issues
issues_uri = "https://github.com/probabl-ai/skore/issues/{issue}"
issues_github_path = "probabl-ai/skore"
issues_user_uri = "https://github.com/{user}"

# The following is used by sphinx.ext.linkcode to provide links to github
linkcode_resolve = make_linkcode_resolve(
"skore",
(
"https://github.com/probabl-ai/"
"skore/blob/{revision}/"
"{package}/src/skore/{path}#L{lineno}"
),
)
89 changes: 89 additions & 0 deletions sphinx/sphinxext/github_link.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Link docs with GitHub lines.
This file is copied from https://github.com/scikit-learn-contrib/imbalanced-learn/blob/master/doc/sphinxext/github_link.py
"""

import inspect
import os
import subprocess
import sys
from functools import partial
from operator import attrgetter

REVISION_CMD = "git rev-parse --short HEAD"


def _get_git_revision():
try:
revision = subprocess.check_output(REVISION_CMD.split()).strip()
except (subprocess.CalledProcessError, OSError):
print("Failed to execute git to get revision")
return None
return revision.decode("utf-8")


def _linkcode_resolve(domain, info, package, url_fmt, revision):
"""Determine a link to online source for a class/method/function
This is called by sphinx.ext.linkcode
An example with a long-untouched module that everyone has
>>> _linkcode_resolve('py', {'module': 'tty',
... 'fullname': 'setraw'},
... package='tty',
... url_fmt='https://hg.python.org/cpython/file/'
... '{revision}/Lib/{package}/{path}#L{lineno}',
... revision='xxxx')
'https://hg.python.org/cpython/file/xxxx/Lib/tty/tty.py#L18'
"""

if revision is None:
return
if domain not in ("py", "pyx"):
return
if not info.get("module") or not info.get("fullname"):
return

class_name = info["fullname"].split(".")[0]
module = __import__(info["module"], fromlist=[class_name])
obj = attrgetter(info["fullname"])(module)

# Unwrap the object to get the correct source
# file in case that is wrapped by a decorator
obj = inspect.unwrap(obj)

try:
fn = inspect.getsourcefile(obj)
except Exception:
fn = None
if not fn:
try:
fn = inspect.getsourcefile(sys.modules[obj.__module__])
except Exception:
fn = None
if not fn:
return

fn = os.path.relpath(fn, start=os.path.dirname(__import__(package).__file__))
try:
lineno = inspect.getsourcelines(obj)[1]
except Exception:
lineno = ""
return url_fmt.format(revision=revision, package=package, path=fn, lineno=lineno)


def make_linkcode_resolve(package, url_fmt):
"""Returns a linkcode_resolve function for the given URL format
revision is a git commit reference (hash or name)
package is the name of the root module of the package
url_fmt is along the lines of ('https://github.com/USER/PROJECT/'
'blob/{revision}/{package}/'
'{path}#L{lineno}')
"""
revision = _get_git_revision()
return partial(
_linkcode_resolve, revision=revision, package=package, url_fmt=url_fmt
)

0 comments on commit ccafc68

Please sign in to comment.