Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Aug 19, 2024
1 parent 1969a73 commit 48063e6
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/004_constraints/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,6 @@ def damped_harmonic_oscillator(t, y_0, l, r, g, c):

# Optional: plot the fit results.
plot = Plot(fit)
plot.plot(residual=True, asymmetric_parameter_errors=True)
plot.plot(pull=True, asymmetric_parameter_errors=True)

plot.show()
1 change: 1 addition & 0 deletions kafe2/config/kafe2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ fit:
y: '$y$'
ratio_label: 'Ratio'
residual_label: 'Residual'
pull_label: 'Pull'
error_label: "%(model_label)s $\\pm 1\\sigma$"

style: !include plot_style_color.yaml
64 changes: 61 additions & 3 deletions kafe2/fit/_base/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ class PlotAdapterBase:
plot_adapter_method="plot_residual",
target_axes="residual",
),
pull=dict(
plot_style_as="data",
plot_adapter_method="plot_pull",
target_axes="pull",
),
)

AVAILABLE_X_SCALES = ("linear",)
Expand Down Expand Up @@ -626,6 +631,27 @@ def plot_residual(self, target_axes, error_contributions=("data",), **kwargs):
**kwargs,
)

def plot_pull(self, target_axes, error_contributions=("data",), **kwargs):
"""Plot the data/model ratio to a specified :py:obj:`matplotlib.axes.Axes` object.
:param matplotlib.axes.Axes target_axes: The :py:obj:`matplotlib` axes used for plotting.
:param error_contributions: Which error contributions to include when plotting the data.
Can either be ``data``, ``'model'`` or both.
:type error_contributions: str or Tuple[str]
:param dict kwargs: Keyword arguments accepted by :py:obj:`matplotlib.pyplot.errorbar`.
:return: plot handle(s)
"""

_xmin, _xmax = self.x_range
target_axes.fill_between([_xmin, _xmax], -2, 2, color=[0.0, 0.0, 0.0, 0.1])
target_axes.fill_between([_xmin, _xmax], -1, 1, color=[0.0, 0.0, 0.0, 0.1])
target_axes.hlines(y=0, xmin=_xmin, xmax=_xmax, colors="black", linestyles=":")

_yerr = self._get_total_error(error_contributions)

# TODO: how to handle case when x and y error/model differ?
return target_axes.errorbar(self.data_x, (self.model_y - self.data_y) / _yerr, xerr=0, yerr=0, **kwargs)

# Overridden by multi plot adapters
def get_formatted_model_function(self, **kwargs):
"""return model function string"""
Expand Down Expand Up @@ -1240,6 +1266,9 @@ def plot(
residual=False,
residual_range=None,
residual_height_share=0.25,
pull=False,
pull_range=None,
pull_height_share=0.25,
plot_width_share=0.5,
font_scale=1.0,
figsize=None,
Expand Down Expand Up @@ -1268,8 +1297,15 @@ def plot(
:return: dictionary containing information about the plotted objects
:rtype: dict
"""
if ratio and residual:
raise NotImplementedError("Cannot plot ratio and residual at the same time.")
_num_extra_plots = 0
if ratio:
_num_extra_plots += 1
if residual:
_num_extra_plots += 1
if pull:
_num_extra_plots += 1
if _num_extra_plots > 1:
raise NotImplementedError("Only one out of ratio, residual, and pull can be used at the same time.")

with rc_context(kafe2_rc):
rcParams["font.size"] *= font_scale
Expand All @@ -1281,10 +1317,14 @@ def plot(
_axes_keys += ("ratio",)
_height_ratios[0] -= ratio_height_share
_height_ratios.append(ratio_height_share)
elif residual:
if residual:
_axes_keys += ("residual",)
_height_ratios[0] -= residual_height_share
_height_ratios.append(residual_height_share)
if pull:
_axes_keys += ("pull",)
_height_ratios[0] -= pull_height_share
_height_ratios.append(pull_height_share)

_all_plot_results = []
for i in range(len(self._fits) if self._separate_figs else 1):
Expand Down Expand Up @@ -1356,6 +1396,24 @@ def plot(
_axis.set_ylim((_low, _high))
else:
_axis.set_ylim(residual_range)
if pull:
_axis = self._current_axes["pull"]
_pull_label = kc("fit", "plot", "pull_label")
_axis.set_ylabel(_pull_label)
if pull_range is None:
_plot_adapters = self._get_plot_adapters()[i : i + 1] if self._separate_figs else self._get_plot_adapters()
_max_abs_deviation = 0
for _plot_adapter in _plot_adapters:
_max_abs_deviation = max(
_max_abs_deviation,
np.max(np.abs((_plot_adapter.data_y - _plot_adapter.model_y) / _plot_adapter.data_yerr)),
)
# Small gap between highest error bar and plot border:
_low = -_max_abs_deviation * 1.05
_high = _max_abs_deviation * 1.05
_axis.set_ylim((_low, _high))
else:
_axis.set_ylim(pull_range)

_all_plot_results.append(_plot_results)

Expand Down

0 comments on commit 48063e6

Please sign in to comment.