From 24b3db936485e2eb74d5b367f87d9a55f6809981 Mon Sep 17 00:00:00 2001 From: Scott Shambaugh Date: Sun, 23 Jun 2024 19:08:08 -0600 Subject: [PATCH] Specify axis for wiggle stereograms --- CHANGELOG.md | 3 ++- src/mpl_stereo/AxesStereo.py | 29 ++++++++++++++++++++++------- tests/test_AxesStereo.py | 3 ++- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 092c895..cea5d47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,8 +14,9 @@ ## [Unreleased] ### Added +* Can specify a target axis for wiggle stereograms ### Changed -* Fix Axes3D.plot and .plot3D not working +* Fix `Axes3D.plot()` and `.plot3D()` not working ### Removed ## [0.8.1] - 2024-02-19 diff --git a/src/mpl_stereo/AxesStereo.py b/src/mpl_stereo/AxesStereo.py index f328011..779302c 100644 --- a/src/mpl_stereo/AxesStereo.py +++ b/src/mpl_stereo/AxesStereo.py @@ -379,7 +379,7 @@ def __init__(self, self.axs = (self.ax_left, self.ax_right) def wiggle(self, filepath: Union[str, Path], interval: float = 125, - *args: Any, **kwargs: dict[str, Any]): + ax: Optional[Axes] = None, *args: Any, **kwargs: dict[str, Any]): """ Save the figure as a wiggle stereogram. @@ -389,31 +389,46 @@ def wiggle(self, filepath: Union[str, Path], interval: float = 125, The filepath to save the figure to. interval : float The interval between frames in milliseconds, default 125. + ax : matplotlib.axes.Axes, optional + The target axes to plot the wiggle stereogram on. If None (default), + then will plot on the axes of a new Figure. *args : Any Additional arguments passed to animation.save. **kwargs : dict[str, Any] Additional keyword arguments passed to animation.save. """ filepath = Path(filepath) + if ax is None: + fig, ax_target = plt.subplots() + else: + ax_target = ax + fig = ax_target.figure + pos = ax_target.get_position() # Duplicate the figure buf = io.BytesIO() pickle.dump(self.fig, buf) buf.seek(0) - fig = pickle.load(buf) + fig_buffer = pickle.load(buf) + + fig.delaxes(ax_target) # Set up the axes to fill the figure - for ax in fig.axes: - ax.set_position([0.125, 0.11, 0.775, 0.77]) # Default full size + axs = fig_buffer.axes + for ax in axs: + ax.figure = fig + fig.axes.append(ax) + fig.add_axes(ax) + ax.set_position(pos) ax.set_visible(False) # Hide initially def update(frame): - fig.axes[frame].set_visible(True) + axs[frame].set_visible(True) # Don't hide the underlying left axis for 2D plots, since the right - # one has the y axis hidded + # one has the y axis hidden if self.is_3d and frame == 1: - fig.axes[0].set_visible(False) + axs[0].set_visible(False) return ax, ani = FuncAnimation(fig, update, frames=2, interval=interval) diff --git a/tests/test_AxesStereo.py b/tests/test_AxesStereo.py index 87f4381..4087a0e 100644 --- a/tests/test_AxesStereo.py +++ b/tests/test_AxesStereo.py @@ -244,6 +244,7 @@ def test_wiggle(): # Smoke test wiggle wiggle_filepath = Path('test.gif') x, y, z = _testdata()['trefoil'] + axstereo = AxesStereo2D() axstereo.plot(x, y, z) axstereo.wiggle(wiggle_filepath) @@ -374,7 +375,7 @@ def plotting_tests_wiggle(filepath): axstereo = AxesStereo2D() axstereo.ax_left.imshow(church_left_data) axstereo.ax_right.imshow(church_right_data) - axstereo.wiggle(filepath) + axstereo.wiggle(filepath, ax=None) if __name__ == '__main__':