diff --git a/.gitignore b/.gitignore index 1c552b2..89dac4c 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,7 @@ _version.py # mpl testing result_images/ + + +# Pycharm paths +.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 782cc34..203d4c7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,24 +2,24 @@ ci: autoupdate_schedule: "quarterly" repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 22.8.0 + rev: 23.1.0 hooks: - id: black - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/nbQA-dev/nbQA - rev: 1.5.2 + rev: 1.6.3 hooks: - id: nbqa-black - id: nbqa-isort @@ -30,18 +30,18 @@ repos: - id: nbstripout - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.0-alpha.0 + rev: v3.0.0-alpha.6 hooks: - id: prettier - repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 + rev: 6.0.0 hooks: - id: flake8 - additional_dependencies: [flake8-typing-imports==1.7.0] + additional_dependencies: [flake8-typing-imports==1.12.0] - repo: https://github.com/PyCQA/autoflake - rev: v1.6.1 + rev: v2.0.1 hooks: - id: autoflake args: diff --git a/docs/contributing.md b/docs/contributing.md index 28cb0c2..3b22b01 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -17,7 +17,7 @@ pre-commit install The `conda env create` command installs all Python packages that are useful when working on the source code of `mpl_interactions` and its documentation. You can also install these packages separately: ```bash -pip install -e .[dev] +pip install -e ".[dev]" ``` The {command}`-e .` flag installs the `mpl_interactions` folder in ["editable" mode](https://pip.pypa.io/en/stable/cli/pip_install/#editable-installs) and {command}`[dev]` installs the [optional dependencies](https://setuptools.readthedocs.io/en/latest/userguide/dependency_management.html#optional-dependencies) you need for developing `mpl_interacions`. diff --git a/docs/examples/custom-callbacks.ipynb b/docs/examples/custom-callbacks.ipynb index d85f20f..5162a58 100644 --- a/docs/examples/custom-callbacks.ipynb +++ b/docs/examples/custom-callbacks.ipynb @@ -157,6 +157,7 @@ "\n", "# attach a custom callback\n", "\n", + "\n", "# if running from a script you can just delete the widgets.Output and associated code\n", "def my_callback(tau, beta):\n", " if tau < 7.5:\n", diff --git a/docs/examples/devlop/devlop-controller.ipynb b/docs/examples/devlop/devlop-controller.ipynb index f593b32..379c310 100644 --- a/docs/examples/devlop/devlop-controller.ipynb +++ b/docs/examples/devlop/devlop-controller.ipynb @@ -310,7 +310,6 @@ " )\n", "\n", " def update(params, indices):\n", - "\n", " # update plot\n", " for i, f in enumerate(funcs):\n", " if x is not None and not indexed_x:\n", @@ -379,7 +378,6 @@ "\n", " lines = []\n", " for i, f in enumerate(funcs):\n", - "\n", " if x is not None and not indexed_x:\n", " lines.append(ax.plot(x, f(x, **params), **plot_kwargs[i])[0])\n", " elif indexed_x:\n", diff --git a/docs/examples/image-segmentation-multiple-images.ipynb b/docs/examples/image-segmentation-multiple-images.ipynb new file mode 100644 index 0000000..8dc6d80 --- /dev/null +++ b/docs/examples/image-segmentation-multiple-images.ipynb @@ -0,0 +1,201 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Image Segmentation of overlayed images and multi-dimensional images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%matplotlib widget\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "import mpl_interactions as mpl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# load a primary sample image\n", + "import urllib\n", + "\n", + "import PIL\n", + "\n", + "url = \"https://github.com/matplotlib/matplotlib/raw/v3.3.0/lib/matplotlib/mpl-data/sample_data/ada.png\"\n", + "\n", + "\n", + "image = np.sum(np.array(PIL.Image.open(urllib.request.urlopen(url))), axis=2) / 255\n", + "# simulating multiple slices along the 2 dimension\n", + "image_stack = np.array([image] * 3).transpose((1, 2, 0))\n", + "\n", + "# secondary image\n", + "url_2 = \"https://github.com/matplotlib/matplotlib/raw/v3.3.0/lib/matplotlib/mpl-data/sample_data/Minduka_Present_Blue_Pack.png\"\n", + "secondary_image = np.sum(np.array(PIL.Image.open(urllib.request.urlopen(url_2))), axis=2) / 255\n", + "# padding secondard to be same size as primary for illustration\n", + "secondary_image_padded = np.pad(\n", + " secondary_image,\n", + " [\n", + " (170, 505),\n", + " (150, 234),\n", + " ],\n", + " mode=\"constant\",\n", + " constant_values=(0, 0),\n", + ")\n", + "# simulating multiple slices along the 2 dimension\n", + "secondary_image_stack = np.array([secondary_image_padded] * 3).transpose(1, 2, 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 3, tight_layout=True)\n", + "ax[0].imshow(image, cmap=\"gray\")\n", + "ax[1].imshow(secondary_image, cmap=\"magma\")\n", + "ax[2].imshow(image, cmap=\"gray\")\n", + "ax[2].imshow(secondary_image_padded, cmap=\"magma\", alpha=0.6)\n", + "ax[0].set_title(\"Primary image\")\n", + "ax[1].set_title(\"Secondary image\")\n", + "ax[2].set_title(\"Overlayed\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Now segment the secondary image over the primary\n", + "1. First create a stack of image_segmenter_overlayed objects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "segmenter_stack = mpl.get_segmenter_list(image_stack, secondary_image_stack, n_classes=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "2. Call draw_masks function to open an interactive window where one can draw the masks on the seperate slices" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "mpl.draw_masks(segmenter_list=segmenter_stack)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "3. Retrieve mask values from segmenter_list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "masks_dict = mpl.get_masks(segmenter_stack, plot_res=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "4. Retrieve contours of the masks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "contours = mpl.get_mask_contours(segmenter_stack)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "5. Plot contours on segmented images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig,ax=plt.subplots(1,image_stack.shape[2])\n", + "\n", + "for n in range(len(image_stack.shape[2])):\n", + " ax[n].imshow(image_stack[:,:,n]) for n in range(len(image_stack.shape[2])\n", + " if len(contours[slice])>0:\n", + " for roi_num in range(len(contours[n])):\n", + " ax.plot(contours[n][roi_num][:,1], contours[n][roi_num][:,0], linewidth=2, color='C'+str(roi_num),label='ROI Nr.'+str(roi_num))\n", + "\n", + "ax.legend()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/examples/usage.ipynb b/docs/examples/usage.ipynb index a76f59f..99381ac 100644 --- a/docs/examples/usage.ipynb +++ b/docs/examples/usage.ipynb @@ -452,6 +452,8 @@ ")\n", "\n", "iplt.title(\"the value of tau is: {tau:.2f}\", controls=controls[\"tau\"])\n", + "\n", + "\n", "# you can still use plt commands if this is the active figure\n", "def ylabel(tau):\n", " return f\"tau/2 is {np.round(tau/2,3)}\"\n", diff --git a/docs/index.md b/docs/index.md index 53c79e1..91a6ea4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -116,6 +116,7 @@ examples/imshow.ipynb examples/hist.ipynb examples/scatter-selector.ipynb examples/image-segmentation.ipynb +examples/image-segmentation-multiple-image.ipynb examples/zoom-factory.ipynb examples/heatmap-slicer.ipynb ``` diff --git a/mpl_interactions/generic.py b/mpl_interactions/generic.py index ccf6c21..b8a167e 100644 --- a/mpl_interactions/generic.py +++ b/mpl_interactions/generic.py @@ -2,7 +2,9 @@ from collections.abc import Callable +import ipywidgets as widgets import numpy as np +from IPython.display import display from matplotlib import __version__ as mpl_version from matplotlib import get_backend from matplotlib.colors import TABLEAU_COLORS, XKCD_COLORS, to_rgba_array @@ -10,6 +12,7 @@ from matplotlib.pyplot import close, ioff, subplots from matplotlib.widgets import LassoSelector from numpy import asanyarray, asarray, max, min +from skimage import measure from .controller import gogogo_controls, prep_scalars from .helpers import ( @@ -27,6 +30,11 @@ "zoom_factory", "panhandler", "image_segmenter", + "image_segmenter_overlayed", + "get_segmenter_list", + "draw_masks", + "get_masks", + "get_mask_contours", "hyperslicer", ] @@ -44,7 +52,6 @@ def heatmap_slicer( figsize=(18, 9), **pcolormesh_kwargs, ): - """ Compare horizontal and/or vertical slices across multiple arrays. @@ -440,7 +447,6 @@ def __init__( """ Create an image segmenter. Any ``kwargs`` will be passed through to the ``imshow`` call that displays *img*. - Parameters ---------- img : array_like @@ -567,6 +573,382 @@ def _ipython_display_(self): display(self.fig.canvas) # noqa: F405, F821 +class image_segmenter_overlayed: + """ + Manually segment an overlay of two images with the lasso selector. + """ + + def __init__( + self, + img, + second_img, + img_extent=None, + second_img_extent=None, + second_img_alpha=0.2, + second_img_cmap="viridis", + nclasses=1, + mask=None, + mask_colors=None, + mask_alpha=0.75, + lineprops=None, + props=None, + lasso_mousebutton="left", + pan_mousebutton="middle", + ax=None, + figsize=(10, 10), + **kwargs, + ): + """ + Create an image segmenter. Any ``kwargs`` will be passed through to the ``imshow`` + call that displays *img*. + + Parameters + ---------- + img : array_like + A valid argument to imshow + nclasses : int, default 1 + second_img : array_like, optional + A valid argument to imshow. Secondary image that is overlayed with a designated alpha value over the primary + image. + img_extent : list, optional + Extent of primary image as given to imshow. + For example: extent = [-15,15,-10,10] + second_img_extent : list, optional + Extent of secondary image as given to imshow. + Similar to img_extent + second_img_alpha : float, optional + transparency of secondary image, by default 0.2 + second_img_cmap : str, optional + colormap of secondary image, by default 'viridis' + mask : arraylike, optional + If you want to pre-seed the mask + mask_colors : None, color, or array of colors, optional + the colors to use for each class. Unselected regions will always be totally transparent + mask_alpha : float, default .75 + The alpha values to use for selected regions. This will always override the alpha values + in mask_colors if any were passed + lineprops : dict, default: None + DEPRECATED - use props instead. + lineprops passed to LassoSelector. If None the default values are: + {"color": "black", "linewidth": 1, "alpha": 0.8} + props : dict, default: None + props passed to LassoSelector. If None the default values are: + {"color": "black", "linewidth": 1, "alpha": 0.8} + lasso_mousebutton : str, or int, default: "left" + The mouse button to use for drawing the selecting lasso. + pan_mousebutton : str, or int, default: "middle" + The button to use for `~mpl_interactions.generic.panhandler`. One of 'left', 'middle' or + 'right', or 1, 2, 3 respectively. + ax : `matplotlib.axes.Axes`, optional + The axis on which to plot. If *None* a new figure will be created. + figsize : (float, float), optional + passed to plt.figure. Ignored if *ax* is given. + **kwargs + All other kwargs will passed to the imshow command for the image + """ + # ensure mask colors is iterable and the same length as the number of classes + # choose colors from default color cycle? + + self.mask_alpha = mask_alpha + + if mask_colors is None: + # this will break if there are more than 10 classes + if nclasses <= 10: + self.mask_colors = to_rgba_array(list(TABLEAU_COLORS)[:nclasses]) + else: + # up to 949 classes. Hopefully that is always enough.... + self.mask_colors = to_rgba_array(list(XKCD_COLORS)[:nclasses]) + else: + self.mask_colors = to_rgba_array(np.atleast_1d(mask_colors)) + # should probably check the shape here + self.mask_colors[:, -1] = self.mask_alpha + + self._img = np.asarray(img) + self.second_img = np.asarray(second_img) + if mask is None: + self.mask = np.zeros(self._img.shape[:2]) + """See :doc:`/examples/image-segmentation`.""" + else: + self.mask = mask + + self._overlay = np.zeros((*self._img.shape[:2], 4)) + self.nclasses = nclasses + for i in range(nclasses + 1): + idx = self.mask == i + if i == 0: + self._overlay[idx] = [0, 0, 0, 0] + else: + self._overlay[idx] = self.mask_colors[i - 1] + if ax is not None: + self.ax = ax + self.fig = self.ax.figure + else: + with ioff(): + self.fig = figure(figsize=figsize) + self.ax = self.fig.gca() + self.displayed = self.ax.imshow(self._img, extent=img_extent, **kwargs) + # plot the secondary image over the primary + self.displayed = self.ax.imshow( + self.second_img, alpha=second_img_alpha, extent=second_img_extent, cmap=second_img_cmap + ) + + self._mask = self.ax.imshow(self._overlay) + + default_props = {"color": "black", "linewidth": 1, "alpha": 0.8} + if (props is None) and (lineprops is None): + props = default_props + elif (lineprops is not None) and (mpl_version >= "3.7"): + print("*lineprops* is deprecated - please use props") + props = {"color": "black", "linewidth": 1, "alpha": 0.8} + + useblit = False if "ipympl" in get_backend().lower() else True + button_dict = {"left": 1, "middle": 2, "right": 3} + if isinstance(pan_mousebutton, str): + pan_mousebutton = button_dict[pan_mousebutton.lower()] + if isinstance(lasso_mousebutton, str): + lasso_mousebutton = button_dict[lasso_mousebutton.lower()] + + if mpl_version < "3.7": + self.lasso = LassoSelector( + self.ax, self._onselect, lineprops=props, useblit=useblit, button=lasso_mousebutton + ) + else: + self.lasso = LassoSelector( + self.ax, self._onselect, props=props, useblit=useblit, button=lasso_mousebutton + ) + self.lasso.set_visible(True) + + pix_x = np.arange(self._img.shape[0]) + pix_y = np.arange(self._img.shape[1]) + xv, yv = np.meshgrid(pix_y, pix_x) + self.pix = np.vstack((xv.flatten(), yv.flatten())).T + + self.ph = panhandler(self.fig, button=pan_mousebutton) + self.disconnect_zoom = zoom_factory(self.ax) + self.current_class = 1 + self.erasing = False + + def _onselect(self, verts): + self.verts = verts + p = Path(verts) + self.indices = p.contains_points(self.pix, radius=0).reshape(self.mask.shape) + if self.erasing: + self.mask[self.indices] = 0 + self._overlay[self.indices] = [0, 0, 0, 0] + else: + self.mask[self.indices] = self.current_class + self._overlay[self.indices] = self.mask_colors[self.current_class - 1] + + self._mask.set_data(self._overlay) + self.fig.canvas.draw_idle() + + def _ipython_display_(self): + display(self.fig.canvas) # noqa: F405, F821 + + +def get_segmenter_list( + primary_image_stack, + secondary_image_stack, + primary_img_extent=None, + secondary_image_ext=None, + overlay=0.25, + secondary_image_cmap="viridis", + n_classes=1, + figsize=(5, 5), +): + """ + Returns a list of image_segmenter_overlayed type entries. + Parameters + ---------- + primary_image_stack: array like + Contains primary images as a 3D array, exemplary shape (128,128,5) for 5 slices + secondary_image_stack: array like + Contains secondary images as a 3D array, has to have same shape[2] as primary_image_stack + primary_img_extent : list, optional + If primary and secondary images don't have the same extent (i.e. shape[0],[1]) this need to be given. + Gets passed to imshow(extent = primary_img_extent), for example something like [-15,15,-10,10] + secondary_image_ext : list, optional + If primary and secondary images don't have the same extent (i.e. shape[0],[1]) this need to be given. + Gets passed to imshow(extent = secondary_image_ext), for example something like [-15,15,-10,10] + overlay : float + alpha value of the secondary image, per default 0.25 + secondary_image_cmap : str, optional + secondary image colormap, per default 'viridis' + n_classes: int, optional + number of channels, per default 1. + Returns + ------- + seg_list: list + Contains image_segmenter_overlayed type objects + """ + line_properties = {"color": "red", "linewidth": 1} + seg_list = [ + image_segmenter_overlayed( + primary_image_stack[:, :, s], + secondary_image_stack[:, :, s], + img_extent=primary_img_extent, + second_img_extent=secondary_image_ext, + second_img_alpha=overlay, + second_img_cmap=secondary_image_cmap, + figsize=figsize, + nclasses=n_classes, + lineprops=None, + props=line_properties, + mask_alpha=0.76, + cmap="gray", + ) + for s in range(primary_image_stack.shape[2]) + ] + return seg_list + + +def draw_masks(segmenter_list, roi_names=None): + """ + Loads a segmenter_list and then allows the user to draw ROIs which are saved in the segmenter_list. + Parameters + ---------- + segmenter_list: list of image_segmenter objects + Usually returned from get_segmenter_list. + roi_names: list of str, optional + Names of ROIs as str for better overview when segmenting multiple ROIs. + """ + + # define image plotting function + def plot_imgs(n_slice, eraser_mode, roi_key): + temp_seg = segmenter_list[n_slice] + temp_seg.erasing = eraser_mode + if roi_names: + # names instead of numbers + # +1 to have the same colorscheme as we start with + # index 0 + roi_number = roi_names.index(roi_key) + 1 + else: + # default numbering + roi_number = roi_key + temp_seg.current_class = roi_number + display(temp_seg) + + n_rois = segmenter_list[0].nclasses + n_slices = len(segmenter_list) + + # Making the UI + if roi_names: + class_selector = widgets.Dropdown(options=roi_names, description="ROI name") + else: + class_selector = widgets.Dropdown( + options=list(range(1, n_rois + 1)), description="ROI number" + ) + + erasing_button = widgets.Checkbox(value=False, description="Erasing") + # create interactive slider for echoes + + slice_slider = widgets.IntSlider( + value=n_slices // 2, min=0, max=n_slices - 1, description="Slice: " + ) + + # put both sliders inside a HBox for nice alignment etc. + ui = widgets.HBox( + [erasing_button, slice_slider, class_selector], + layout=widgets.Layout(display="flex"), + ) + + sliders = widgets.interactive_output( + plot_imgs, + {"n_slice": slice_slider, "eraser_mode": erasing_button, "roi_key": class_selector}, + ) + + display(ui, sliders) + + +def get_masks(segmenter_list, roi_keys=None, plot_res=False): + """ + Extract the masks for a given segmenter list. + + Parameters + --------- + segmenter_list: list of image_segmenter_overlayed objects + roi_keys: list of str, optional + Default is none, then just strings of numbers from 0-number_of_rois are the keys + plot_res: bool, optional. + if one wants the result to be plotted, default is False. + Returns + -------- + mask_per_slice: dict + entries can be called via the selected keys. + + Examples + -------- + If we give keys: + masks = get_masks(segmenter_list,['Tumor','Kidney','Vessel']) + masked_secondary_image = masks['Tumor'] * secondary_image + + If we dont give keys: + masks = ut_anat.get_masks_multi_rois(segmenter_list) + masked_secondary_image = masks['1'] * secondary_image + """ + n_slices = len(segmenter_list) + n_rois = segmenter_list[0].nclasses + if not roi_keys: + # set default names + roi_keys = [str(n) for n in range(1, n_rois + 1)] + else: + # use given keys + pass + mask_per_slice = np.zeros( + (segmenter_list[0].mask.shape[0], segmenter_list[0].mask.shape[1], n_slices, n_rois) + ) + for slic in range(0, n_slices): + for roi in range(0, n_rois): + test_mask = segmenter_list[slic].mask == roi + 1 + mask_per_slice[:, :, slic, roi] = test_mask + + mask_dict = dict() + for idx, roi_key in enumerate(roi_keys): + mask_dict.update({roi_key: mask_per_slice[:, :, :, idx - 1]}) + + if plot_res: + fig, ax = subplots(1, n_rois) + + @widgets.interact(slices=(0, n_slices - 1, 1)) + def update(slices=0): + if n_rois > 1: + [ax[n].imshow(mask_per_slice[:, :, slices, n]) for n in range(n_rois)] + [ax[n].set_title("ROI " + str(roi_keys[n])) for n in range(n_rois)] + else: + ax.imshow(mask_per_slice[:, :, slices, 0]) + ax.set_title("ROI " + str(roi_keys[0])) + + return mask_dict + + +def get_mask_contours(segmenter_list): + """ + Extract the contours of a mask segmented with mpl_interactions image_segmenter_overlayed + + Parameters + --------- + segmenter_list: list of image_segmenter_overlayed objects + + Returns + -------- + contours: list(np.arrays) + + Examples + --------- + To plot the contours on a primary image: + contours = get_roi_coords(seg_list) + fig,ax=subplots(1) + ax.imshow(primary_image) + for contour in contours: + ax.plot(contour[:, 1], contour[:, 0], linewidth=2, color='r') + + """ + contours = [] + for n in range(len(segmenter_list)): + contours.append(measure.find_contours(segmenter_list[n].mask)) + return contours + + def hyperslicer( arr, cmap=None, @@ -594,7 +976,6 @@ def hyperslicer( display_controls=True, **kwargs, ): - """ View slices from a hyperstack of images selected by sliders. Also accepts Xarray.DataArrays in which case the axes names and coordinates will be inferred from the xarray dims and coords. @@ -697,7 +1078,6 @@ def hyperslicer( # Just pass in an array - no kwargs for i in range(arr.ndim - im_dims): - start, stop = None, None name = f"axis{i}" if name in kwargs: diff --git a/setup.cfg b/setup.cfg index a013e51..51d0140 100644 --- a/setup.cfg +++ b/setup.cfg @@ -70,6 +70,7 @@ test = requests scipy xarray + scikit-image dev = %(doc)s %(jupyter)s