From 48063e69f9b97b05041dd0e0d7358b86aba931bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 19 Aug 2024 12:35:04 +0200 Subject: [PATCH] WIP --- examples/004_constraints/constraints.py | 2 +- kafe2/config/kafe2.yaml | 1 + kafe2/fit/_base/plot.py | 64 +++++++++++++++++++++++-- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/examples/004_constraints/constraints.py b/examples/004_constraints/constraints.py index 544f4d3d..a0622912 100644 --- a/examples/004_constraints/constraints.py +++ b/examples/004_constraints/constraints.py @@ -88,6 +88,6 @@ def damped_harmonic_oscillator(t, y_0, l, r, g, c): # Optional: plot the fit results. plot = Plot(fit) -plot.plot(residual=True, asymmetric_parameter_errors=True) +plot.plot(pull=True, asymmetric_parameter_errors=True) plot.show() diff --git a/kafe2/config/kafe2.yaml b/kafe2/config/kafe2.yaml index d35e5d94..f187d64f 100644 --- a/kafe2/config/kafe2.yaml +++ b/kafe2/config/kafe2.yaml @@ -30,6 +30,7 @@ fit: y: '$y$' ratio_label: 'Ratio' residual_label: 'Residual' + pull_label: 'Pull' error_label: "%(model_label)s $\\pm 1\\sigma$" style: !include plot_style_color.yaml diff --git a/kafe2/fit/_base/plot.py b/kafe2/fit/_base/plot.py index b3e0d503..691692b2 100644 --- a/kafe2/fit/_base/plot.py +++ b/kafe2/fit/_base/plot.py @@ -163,6 +163,11 @@ class PlotAdapterBase: plot_adapter_method="plot_residual", target_axes="residual", ), + pull=dict( + plot_style_as="data", + plot_adapter_method="plot_pull", + target_axes="pull", + ), ) AVAILABLE_X_SCALES = ("linear",) @@ -626,6 +631,27 @@ def plot_residual(self, target_axes, error_contributions=("data",), **kwargs): **kwargs, ) + def plot_pull(self, target_axes, error_contributions=("data",), **kwargs): + """Plot the data/model ratio to a specified :py:obj:`matplotlib.axes.Axes` object. + + :param matplotlib.axes.Axes target_axes: The :py:obj:`matplotlib` axes used for plotting. + :param error_contributions: Which error contributions to include when plotting the data. + Can either be ``data``, ``'model'`` or both. + :type error_contributions: str or Tuple[str] + :param dict kwargs: Keyword arguments accepted by :py:obj:`matplotlib.pyplot.errorbar`. + :return: plot handle(s) + """ + + _xmin, _xmax = self.x_range + target_axes.fill_between([_xmin, _xmax], -2, 2, color=[0.0, 0.0, 0.0, 0.1]) + target_axes.fill_between([_xmin, _xmax], -1, 1, color=[0.0, 0.0, 0.0, 0.1]) + target_axes.hlines(y=0, xmin=_xmin, xmax=_xmax, colors="black", linestyles=":") + + _yerr = self._get_total_error(error_contributions) + + # TODO: how to handle case when x and y error/model differ? + return target_axes.errorbar(self.data_x, (self.model_y - self.data_y) / _yerr, xerr=0, yerr=0, **kwargs) + # Overridden by multi plot adapters def get_formatted_model_function(self, **kwargs): """return model function string""" @@ -1240,6 +1266,9 @@ def plot( residual=False, residual_range=None, residual_height_share=0.25, + pull=False, + pull_range=None, + pull_height_share=0.25, plot_width_share=0.5, font_scale=1.0, figsize=None, @@ -1268,8 +1297,15 @@ def plot( :return: dictionary containing information about the plotted objects :rtype: dict """ - if ratio and residual: - raise NotImplementedError("Cannot plot ratio and residual at the same time.") + _num_extra_plots = 0 + if ratio: + _num_extra_plots += 1 + if residual: + _num_extra_plots += 1 + if pull: + _num_extra_plots += 1 + if _num_extra_plots > 1: + raise NotImplementedError("Only one out of ratio, residual, and pull can be used at the same time.") with rc_context(kafe2_rc): rcParams["font.size"] *= font_scale @@ -1281,10 +1317,14 @@ def plot( _axes_keys += ("ratio",) _height_ratios[0] -= ratio_height_share _height_ratios.append(ratio_height_share) - elif residual: + if residual: _axes_keys += ("residual",) _height_ratios[0] -= residual_height_share _height_ratios.append(residual_height_share) + if pull: + _axes_keys += ("pull",) + _height_ratios[0] -= pull_height_share + _height_ratios.append(pull_height_share) _all_plot_results = [] for i in range(len(self._fits) if self._separate_figs else 1): @@ -1356,6 +1396,24 @@ def plot( _axis.set_ylim((_low, _high)) else: _axis.set_ylim(residual_range) + if pull: + _axis = self._current_axes["pull"] + _pull_label = kc("fit", "plot", "pull_label") + _axis.set_ylabel(_pull_label) + if pull_range is None: + _plot_adapters = self._get_plot_adapters()[i : i + 1] if self._separate_figs else self._get_plot_adapters() + _max_abs_deviation = 0 + for _plot_adapter in _plot_adapters: + _max_abs_deviation = max( + _max_abs_deviation, + np.max(np.abs((_plot_adapter.data_y - _plot_adapter.model_y) / _plot_adapter.data_yerr)), + ) + # Small gap between highest error bar and plot border: + _low = -_max_abs_deviation * 1.05 + _high = _max_abs_deviation * 1.05 + _axis.set_ylim((_low, _high)) + else: + _axis.set_ylim(pull_range) _all_plot_results.append(_plot_results)