From 8628299d2720b303126e14338052d4bf58c80cd3 Mon Sep 17 00:00:00 2001 From: Josef Heinen Date: Thu, 2 Jan 2025 14:57:12 +0100 Subject: [PATCH] Add support for Marimo notebooks --- gr/__init__.py | 10 +++-- gr/pygr/mlab.py | 109 ++++++++++++++++++++++++------------------------ 2 files changed, 61 insertions(+), 58 deletions(-) diff --git a/gr/__init__.py b/gr/__init__.py index f28347d..c88f4f3 100644 --- a/gr/__init__.py +++ b/gr/__init__.py @@ -2733,11 +2733,14 @@ def show(): emergencyclosegks() if _mime_type == 'svg': try: - data = open('gks.svg', 'rb').read() + data = open('gks.svg', 'r').read() except IOError: return None if not data: return None + if 'marimo' in sys.modules: + from marimo import Html + return Html(data) content = SVG(data=data) display(content) elif _mime_type == 'png': @@ -4520,6 +4523,7 @@ def version(): VOLUME_WITHOUT_BORDER = 0 VOLUME_WITH_BORDER = 1 -# automatically switch to inline graphics in Jupyter Notebooks -if 'ipykernel' in sys.modules: +# automatically switch to inline graphics in Jupyter or Marimo notebooks +if 'ipykernel' in sys.modules or 'marimo' in sys.modules: inline() + diff --git a/gr/pygr/mlab.py b/gr/pygr/mlab.py index 23912b3..524b11e 100644 --- a/gr/pygr/mlab.py +++ b/gr/pygr/mlab.py @@ -78,7 +78,7 @@ def plot(*args, **kwargs): _plt.args += _plot_args(args) else: _plt.args = _plot_args(args) - _plot_data(kind='line') + return _plot_data(kind='line') @_close_gks_on_error @@ -107,7 +107,7 @@ def oplot(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args += _plot_args(args) - _plot_data(kind='line') + return _plot_data(kind='line') @_close_gks_on_error @@ -150,7 +150,7 @@ def step(*args, **kwargs): _plt.args += _plot_args(args) else: _plt.args = _plot_args(args) - _plot_data(kind='step') + return _plot_data(kind='step') @_close_gks_on_error @@ -192,7 +192,7 @@ def scatter(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args, fmt='xyac') - _plot_data(kind='scatter') + return _plot_data(kind='scatter') @_close_gks_on_error @@ -220,7 +220,7 @@ def quiver(x, y, u, v, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args((x, y, u, v), fmt='xyuv') - _plot_data(kind='quiver') + return _plot_data(kind='quiver') @_close_gks_on_error @@ -248,7 +248,7 @@ def polar(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args) - _plot_data(kind='polar') + return _plot_data(kind='polar') @_close_gks_on_error @@ -278,7 +278,7 @@ def trisurf(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args, fmt='xyzc') - _plot_data(kind='trisurf') + return _plot_data(kind='trisurf') @_close_gks_on_error @@ -309,7 +309,7 @@ def tricont(x, y, z, *args, **kwargs): _plt.kwargs.update(kwargs) args = [x, y, z] + list(args) _plt.args = _plot_args(args, fmt='xyzc') - _plot_data(kind='tricont') + return _plot_data(kind='tricont') @_close_gks_on_error @@ -340,7 +340,7 @@ def stem(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args) - _plot_data(kind='stem') + return _plot_data(kind='stem') def _hist(x, nbins=0, weights=None): @@ -387,7 +387,7 @@ def histogram(x, num_bins=0, weights=None, **kwargs): _plt.kwargs.update(kwargs) hist, bins = _hist(x, nbins=num_bins, weights=weights) _plt.args = [(np.array(bins), np.array(hist), None, None, "")] - _plot_data(kind='hist') + return _plot_data(kind='hist') @_close_gks_on_error @@ -430,7 +430,7 @@ def contour(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args, fmt='xyzc') - _plot_data(kind='contour') + return _plot_data(kind='contour') @_close_gks_on_error @@ -473,7 +473,7 @@ def contourf(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args, fmt='xyzc') - _plot_data(kind='contourf') + return _plot_data(kind='contourf') @_close_gks_on_error @@ -501,7 +501,7 @@ def hexbin(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args) - _plot_data(kind='hexbin') + return _plot_data(kind='hexbin') @_close_gks_on_error @@ -538,7 +538,7 @@ def heatmap(data, **kwargs): xlim = _plt.kwargs.get('xlim', None) ylim = _plt.kwargs.get('ylim', None) _plt.args = [(xlim, ylim, data, None, "")] - _plot_data(kind='heatmap') + return _plot_data(kind='heatmap') @_close_gks_on_error @@ -575,7 +575,7 @@ def polar_heatmap(data, **kwargs): rlim = _plt.kwargs.get('rlim', None) philim = _plt.kwargs.get('philim', None) _plt.args = [(rlim, philim, data, None, "")] - _plot_data(kind='polar_heatmap') + return _plot_data(kind='polar_heatmap') @_close_gks_on_error @@ -619,7 +619,7 @@ def shade(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args, fmt='xys') - _plot_data(kind='shade') + return _plot_data(kind='shade') @_close_gks_on_error @@ -661,7 +661,7 @@ def wireframe(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args, fmt='xyzc') - _plot_data(kind='wireframe') + return _plot_data(kind='wireframe') @_close_gks_on_error @@ -703,7 +703,7 @@ def surface(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args, fmt='xyzc') - _plot_data(kind='surface') + return _plot_data(kind='surface') @_close_gks_on_error @@ -812,7 +812,7 @@ def bar(y, *args, **kwargs): new_args.append(c) _plt.args = _plot_args(new_args, fmt='y') - _plot_data(kind='bar') + return _plot_data(kind='bar') def polar_histogram(*args, **kwargs): @@ -893,10 +893,7 @@ def polar_histogram(*args, **kwargs): """ global _plt - if _plt.kwargs.get('ax', False) is False: - temp_ax = False - else: - temp_ax = True + saved_ax = _plt.kwargs.get('ax', False) _plt.kwargs['ax'] = True _plt.args = _plot_args(args, fmt='xys') @@ -1307,9 +1304,11 @@ def find_max(classes, normalization): if _plt.kwargs['normalization'] == 'pdf': _plt.kwargs['norm_factor'] = total - _plot_data(kind='polar_histogram') - _plt.kwargs['ax'] = temp_ax - del temp_ax + plt = _plot_data(kind='polar_histogram') + _plt.kwargs['ax'] = saved_ax + del saved_ax + + return plt @_close_gks_on_error @@ -1333,7 +1332,7 @@ def plot3(*args, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = _plot_args(args, fmt='xyac') - _plot_data(kind='plot3') + return _plot_data(kind='plot3') @_close_gks_on_error @@ -1367,7 +1366,7 @@ def scatter3(x, y, z, c=None, *args, **kwargs): if c is not None: args.append(c) _plt.args = _plot_args(args, fmt='xyac') - _plot_data(kind='scatter3') + return _plot_data(kind='scatter3') @_close_gks_on_error @@ -1395,7 +1394,7 @@ def isosurface(v, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = [(None, None, v, None, '')] - _plot_data(kind='isosurface') + return _plot_data(kind='isosurface') @_close_gks_on_error @@ -1430,7 +1429,7 @@ def volume(v, **kwargs): _plt.kwargs.update(kwargs) nz, ny, nx = v.shape _plt.args = [(np.arange(nx + 1), np.arange(nz + 1), np.arange(ny + 1), v, '')] - _plot_data(kind='volume') + return _plot_data(kind='volume') @_close_gks_on_error @@ -1457,7 +1456,7 @@ def imshow(image, **kwargs): global _plt _plt.kwargs.update(kwargs) _plt.args = [(None, None, image, None, "")] - _plot_data(kind='imshow') + return _plot_data(kind='imshow') @_close_gks_on_error @@ -1497,7 +1496,7 @@ def size(*size): size = size[0] if len(size) == 0: size = (600, 450) - _plot_data(size=size) + return _plot_data(size=size) @_close_gks_on_error @@ -1519,7 +1518,7 @@ def title(title=""): >>> # Clear the plot title >>> mlab.title() """ - _plot_data(title=title) + return _plot_data(title=title) @_close_gks_on_error @@ -1541,7 +1540,7 @@ def xlabel(x_label=""): >>> # Clear the x-axis label >>> mlab.xlabel() """ - _plot_data(xlabel=x_label) + return _plot_data(xlabel=x_label) @_close_gks_on_error @@ -1563,7 +1562,7 @@ def ylabel(y_label=""): >>> # Clear the y-axis label >>> mlab.ylabel() """ - _plot_data(ylabel=y_label) + return _plot_data(ylabel=y_label) @_close_gks_on_error @@ -1585,7 +1584,7 @@ def zlabel(z_label=""): >>> # Clear the z-axis label >>> mlab.zlabel() """ - _plot_data(zlabel=z_label) + return _plot_data(zlabel=z_label) @_close_gks_on_error @@ -1607,7 +1606,7 @@ def dlabel(d_label=""): >>> # Clear the volume intensity label >>> mlab.dlabel() """ - _plot_data(dlabel=d_label) + return _plot_data(dlabel=d_label) @_close_gks_on_error @@ -1648,7 +1647,7 @@ def xlim(x_min=None, x_max=None, adjust=True): x_min, x_max = x_min except TypeError: pass - _plot_data(xlim=(x_min, x_max), adjust_xlim=adjust) + return _plot_data(xlim=(x_min, x_max), adjust_xlim=adjust) @_close_gks_on_error @@ -1689,7 +1688,7 @@ def ylim(y_min=None, y_max=None, adjust=True): y_min, y_max = y_min except TypeError: pass - _plot_data(ylim=(y_min, y_max), adjust_ylim=adjust) + return _plot_data(ylim=(y_min, y_max), adjust_ylim=adjust) @_close_gks_on_error @@ -1730,7 +1729,7 @@ def zlim(z_min=None, z_max=None, adjust=True): z_min, z_max = z_min except TypeError: pass - _plot_data(zlim=(z_min, z_max), adjust_zlim=adjust) + return _plot_data(zlim=(z_min, z_max), adjust_zlim=adjust) @_close_gks_on_error @@ -1771,7 +1770,7 @@ def rlim(r_min=None, r_max=None, adjust=True): r_min, r_max = r_min except TypeError: pass - _plot_data(rlim=(r_min, r_max), adjust_rlim=adjust) + return _plot_data(rlim=(r_min, r_max), adjust_rlim=adjust) @_close_gks_on_error @@ -1812,7 +1811,7 @@ def philim(phi_min=None, phi_max=None, adjust=True): phi_min, phi_max = phi_min except TypeError: pass - _plot_data(philim=(phi_min, phi_max), adjust_philim=adjust) + return _plot_data(philim=(phi_min, phi_max), adjust_philim=adjust) @_close_gks_on_error @@ -1829,7 +1828,7 @@ def xlog(xlog=True): >>> # Disable it again >>> mlab.xlog(False) """ - _plot_data(xlog=xlog) + return _plot_data(xlog=xlog) @_close_gks_on_error @@ -1846,7 +1845,7 @@ def ylog(ylog=True): >>> # Disable it again >>> mlab.ylog(False) """ - _plot_data(ylog=ylog) + return _plot_data(ylog=ylog) @_close_gks_on_error @@ -1863,7 +1862,7 @@ def zlog(zlog=True): >>> # Disable it again >>> mlab.zlog(False) """ - _plot_data(zlog=zlog) + return _plot_data(zlog=zlog) @_close_gks_on_error @@ -1880,7 +1879,7 @@ def xflip(xflip=True): >>> # Restores the x-axis >>> mlab.xflip(False) """ - _plot_data(xflip=xflip) + return _plot_data(xflip=xflip) @_close_gks_on_error @@ -1897,7 +1896,7 @@ def yflip(yflip=True): >>> # Restores the y-axis >>> mlab.yflip(False) """ - _plot_data(yflip=yflip) + return _plot_data(yflip=yflip) @_close_gks_on_error @@ -1914,7 +1913,7 @@ def zflip(zflip=True): >>> # Restores the z-axis >>> mlab.zflip(False) """ - _plot_data(zflip=zflip) + return _plot_data(zflip=zflip) @_close_gks_on_error @@ -1931,7 +1930,7 @@ def rflip(rflip=True): >>> # Restores the radius >>> mlab.rflip(False) """ - _plot_data(rflip=rflip) + return _plot_data(rflip=rflip) @_close_gks_on_error @@ -1948,7 +1947,7 @@ def phiflip(phiflip=True): >>> # Restores the angles >>> mlab.phiflip(False) """ - _plot_data(phiflip=phiflip) + return _plot_data(phiflip=phiflip) @_close_gks_on_error @@ -1987,7 +1986,7 @@ def colormap(colormap=''): if colormap == '': _set_colormap() return [(c[0], c[1], c[2]) for c in _colormap()] - _plot_data(colormap=colormap) + return _plot_data(colormap=colormap) @_close_gks_on_error @@ -2022,7 +2021,7 @@ def field_of_view(field_of_view): >>> mlab.field_of_view(None) >>> mlab.plot3(x, y, z) """ - _plot_data(field_of_view=field_of_view) + return _plot_data(field_of_view=field_of_view) @_close_gks_on_error @@ -2050,7 +2049,7 @@ def tilt(tilt): >>> mlab.tilt(45) >>> mlab.plot3(x, y, z) """ - _plot_data(tilt=tilt) + return _plot_data(tilt=tilt) @_close_gks_on_error @@ -2080,7 +2079,7 @@ def rotation(rotation): >>> mlab.rotation(45) >>> mlab.plot3(x, y, z) """ - _plot_data(rotation=rotation) + return _plot_data(rotation=rotation) @_close_gks_on_error @@ -2107,7 +2106,7 @@ def legend(*labels, **kwargs): """ if not all(isinstance(label, basestring) for label in labels): raise TypeError('list of strings expected') - _plot_data(labels=labels, **kwargs) + return _plot_data(labels=labels, **kwargs) @_close_gks_on_error