From 6886b272fa38c95bec8de34546846e1e478e5eb4 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 24 Jan 2024 22:04:51 -0500 Subject: [PATCH] Update Widget organization and Div cost --- scripts/test.py | 16 +- src/motile_plugin/__init__.py | 4 +- src/motile_plugin/_utils.py | 206 ++++++++++++++++++++++++++ src/motile_plugin/_widget.py | 270 ++++++---------------------------- 4 files changed, 264 insertions(+), 232 deletions(-) create mode 100644 src/motile_plugin/_utils.py diff --git a/scripts/test.py b/scripts/test.py index 2cdf0ac..5001a45 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -1,25 +1,25 @@ import napari import zarr -from motile_plugin import ExampleQWidget +from motile_plugin import MotileWidget # Load Zarr datasets -zarr_directory = "/Users/kharrington/git/cmalinmayor/motile-plugin/data/zarr_data.zarr" +zarr_directory = "/Volumes/funke$/lalitm/cellulus/experiments/data/science_meet/Fluo-N2DL-HeLa.zarr" zarr_group = zarr.open_group(zarr_directory, mode='r') -image_stack = zarr_group['stack'][:] -labeled_mask = zarr_group['labeled_stack'][:] -labeled_mask = labeled_mask[0:5, :, :] +image_stack = zarr_group['test/raw'][:,0,:] +labeled_mask = zarr_group['post-processed-segmentation'][:,0,:] +labeled_mask = labeled_mask[:, :, :] # Initialize Napari viewer viewer = napari.Viewer() # Add image and label layers to the viewer -# viewer.add_image(image_stack, name='Image Stack') +viewer.add_image(image_stack, name='Image Stack') viewer.add_labels(labeled_mask, name='Labeled Mask') # Add your custom widget -widget = ExampleQWidget(viewer) +widget = MotileWidget(viewer) viewer.window.add_dock_widget(widget) # Start the Napari GUI event loop -# napari.run() +napari.run() diff --git a/src/motile_plugin/__init__.py b/src/motile_plugin/__init__.py index 40527f6..5505a8c 100644 --- a/src/motile_plugin/__init__.py +++ b/src/motile_plugin/__init__.py @@ -1,6 +1,6 @@ __version__ = "0.0.1" -from ._widget import ExampleQWidget +from ._widget import MotileWidget __all__ = ( - "ExampleQWidget", + "MotileWidget", ) diff --git a/src/motile_plugin/_utils.py b/src/motile_plugin/_utils.py new file mode 100644 index 0000000..a70bc88 --- /dev/null +++ b/src/motile_plugin/_utils.py @@ -0,0 +1,206 @@ +import math +from pathlib import Path +import numpy as np + +from motile import Solver, TrackGraph +from motile.constraints import MaxChildren, MaxParents +from motile.costs import EdgeSelection, Appear, Split +from motile.variables import NodeSelected, EdgeSelected +import networkx as nx +import toml +from tqdm import tqdm +import pprint +import time +from skimage.measure import regionprops +import tifffile +import logging + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s %(name)s %(levelname)-8s %(message)s" +) +logger = logging.getLogger(__name__) + + +def get_location(node_data, loc_keys=("z", "y", "x")): + return [node_data[k] for k in loc_keys] + +def get_cand_graph_from_segmentation( + segmentation, max_edge_distance, pos_labels=["y", "x"] +): + """_summary_ + + Args: + segmentation (np.array): A numpy array with shape (t, [z,], y, x) + """ + # add nodes + node_frame_dict = ( + {} + ) # construct a dictionary from time frame to node_id for efficiency + cand_graph = nx.DiGraph() + + for t in range(len(segmentation)): + nodes_in_frame = [] + props = regionprops(segmentation[t]) + for i, regionprop in enumerate(props): + node_id = f"{t}_{regionprop.label}" # TODO: previously node_id= f"{t}_{i}" + attrs = { + "t": t, + "segmentation_id": regionprop.label, + "area": regionprop.area, + } + centroid = regionprop.centroid # [z,] y, x + for label, value in zip(pos_labels, centroid): + attrs[label] = value + cand_graph.add_node(node_id, **attrs) + nodes_in_frame.append(node_id) + node_frame_dict[t] = nodes_in_frame + + print(f"Candidate nodes: {cand_graph.number_of_nodes()}") + + # add edges + frames = sorted(node_frame_dict.keys()) + for frame in tqdm(frames): + if frame + 1 not in node_frame_dict: + continue + next_nodes = node_frame_dict[frame + 1] + next_locs = [ + get_location(cand_graph.nodes[n], loc_keys=pos_labels) for n in next_nodes + ] + for node in node_frame_dict[frame]: + loc = get_location(cand_graph.nodes[node], loc_keys=pos_labels) + for next_id, next_loc in zip(next_nodes, next_locs): + dist = math.dist(next_loc, loc) + attrs = { + "dist": dist, + } + if dist < max_edge_distance: + cand_graph.add_edge(node, next_id, **attrs) + + print(f"Candidate edges: {cand_graph.number_of_edges()}") + return cand_graph + + + +def solve_with_motile(cand_graph, widget): + motile_cand_graph = TrackGraph(cand_graph) + solver = Solver(motile_cand_graph) + + solver.add_constraints(MaxChildren(widget.get_max_children())) + solver.add_constraints(MaxParents(widget.get_max_parents())) + + if widget.get_distance_weight() is not None: + solver.add_costs(EdgeSelection(widget.get_distance_weight(), attribute="dist", constant=widget.get_distance_offset())) + if widget.get_appear_cost() is not None: + solver.add_costs(Appear(widget.get_appear_cost())) + if widget.get_division_cost() is not None: + print(f"Adding division cost {widget.get_division_cost()}") + solver.add_costs(Split(constant=widget.get_division_cost())) + + start_time = time.time() + solution = solver.solve() + print(f"Solution took {time.time() - start_time} seconds") + return solution, solver + + +def get_solution_nx_graph(solution, solver, cand_graph): + node_selected = solver.get_variables(NodeSelected) + edge_selected = solver.get_variables(EdgeSelected) + + selected_nodes = [ + node for node in cand_graph.nodes if solution[node_selected[node]] > 0.5 + ] + selected_edges = [ + edge for edge in cand_graph.edges if solution[edge_selected[edge]] > 0.5 + ] + + print(f"Selected nodes: {len(selected_nodes)}") + print(f"Selected edges: {len(selected_edges)}") + solution_graph = nx.edge_subgraph(cand_graph, selected_edges) + return solution_graph + + + +def assign_tracklet_ids(graph): + """Add a tracklet_id attribute to a graph by removing division edges, + assigning one id to each connected component. + Designed as a helper for visualizing the graph in the napari Tracks layer. + + Args: + graph (nx.DiGraph): A networkx graph with a tracking solution + + Returns: + nx.DiGraph: The same graph with the tracklet_id assigned. Probably + occurrs in place but returned just to be clear. + """ + graph_copy = graph.copy() + + parents = [node for node, degree in graph.out_degree() if degree >= 2] + intertrack_edges = [] + + # Remove all intertrack edges from a copy of the original graph + for parent in parents: + daughters = [child for p, child in graph.out_edges(parent)] + for daughter in daughters: + graph_copy.remove_edge(parent, daughter) + intertrack_edges.append((parent, daughter)) + + track_id = 0 + for tracklet in nx.weakly_connected_components(graph_copy): + nx.set_node_attributes( + graph, {node: {"tracklet_id": track_id} for node in tracklet} + ) + track_id += 1 + return graph, intertrack_edges + + +def to_napari_tracks_layer( + graph, frame_key="t", location_keys=("y", "x"), properties=() +): + """Function to take a networkx graph and return the data needed to add to + a napari tracks layer. + + Args: + graph (nx.DiGraph): _description_ + frame_key (str, optional): Key in graph attributes containing time frame. + Defaults to "t". + location_keys (tuple, optional): Keys in graph node attributes containing + location. Should be in order: (Z), Y, X. Defaults to ("y", "x"). + properties (tuple, optional): Keys in graph node attributes to add + to the visualization layer. Defaults to (). NOTE: not working now :( + + Returns: + data : array (N, D+1) + Coordinates for N points in D+1 dimensions. ID,T,(Z),Y,X. The first + axis is the integer ID of the track. D is either 3 or 4 for planar + or volumetric timeseries respectively. + properties : dict {str: array (N,)} + Properties for each point. Each property should be an array of length N, + where N is the number of points. + graph : dict {int: list} + Graph representing associations between tracks. Dictionary defines the + mapping between a track ID and the parents of the track. This can be + one (the track has one parent, and the parent has >=1 child) in the + case of track splitting, or more than one (the track has multiple + parents, but only one child) in the case of track merging. + """ + napari_data = np.zeros((graph.number_of_nodes(), len(location_keys) + 2)) + napari_properties = {prop: np.zeros(graph.number_of_nodes()) for prop in properties} + napari_edges = {} + graph, intertrack_edges = assign_tracklet_ids(graph) + for index, node in enumerate(graph.nodes(data=True)): + node_id, data = node + location = [data[loc_key] for loc_key in location_keys] + napari_data[index] = [data["tracklet_id"], data[frame_key]] + location + for prop in properties: + if prop in data: + napari_properties[prop][index] = data[prop] + napari_edges = {} + for parent, child in intertrack_edges: + parent_track_id = graph.nodes[parent]["tracklet_id"] + child_track_id = graph.nodes[child]["tracklet_id"] + if child_track_id in napari_edges: + napari_edges[child_track_id].append(parent_track_id) + else: + napari_edges[child_track_id] = [parent_track_id] + return napari_data, napari_properties, napari_edges + \ No newline at end of file diff --git a/src/motile_plugin/_widget.py b/src/motile_plugin/_widget.py index 0e4e925..3899d92 100644 --- a/src/motile_plugin/_widget.py +++ b/src/motile_plugin/_widget.py @@ -1,45 +1,17 @@ -""" -This module contains four napari widgets declared in -different ways: - -- a pure Python function flagged with `autogenerate: true` - in the plugin manifest. Type annotations are used by - magicgui to generate widgets for each parameter. Best - suited for simple processing tasks - usually taking - in and/or returning a layer. -- a `magic_factory` decorated function. The `magic_factory` - decorator allows us to customize aspects of the resulting - GUI, including the widgets associated with each parameter. - Best used when you have a very simple processing task, - but want some control over the autogenerated widgets. If you - find yourself needing to define lots of nested functions to achieve - your functionality, maybe look at the `Container` widget! -- a `magicgui.widgets.Container` subclass. This provides lots - of flexibility and customization options while still supporting - `magicgui` widgets and convenience methods for creating widgets - from type annotations. If you want to customize your widgets and - connect callbacks, this is the best widget option for you. -- a `QWidget` subclass. This provides maximal flexibility but requires - full specification of widget layouts, callbacks, events, etc. - -References: -- Widget specification: https://napari.org/stable/plugins/guides.html?#widgets -- magicgui docs: https://pyapp-kit.github.io/magicgui/ - -Replace code below according to your needs. -""" from typing import TYPE_CHECKING from magicgui import magic_factory from magicgui.widgets import CheckBox, Container, create_widget from qtpy.QtWidgets import ( QWidget, QPushButton, QSlider, QHBoxLayout, QVBoxLayout, - QLabel, QSpinBox, QCheckBox, QDoubleSpinBox, QGroupBox, QLineEdit + QLabel, QSpinBox, QCheckBox, QDoubleSpinBox, QGroupBox, QLineEdit, + QComboBox ) from qtpy.QtCore import Qt from skimage.util import img_as_float #from napari_graph import UndirectedGraph import numpy as np +from napari.layers import Labels import pandas as pd if TYPE_CHECKING: @@ -63,6 +35,7 @@ import tifffile import logging +from ._utils import get_cand_graph_from_segmentation, solve_with_motile, get_solution_nx_graph, to_napari_tracks_layer logging.basicConfig( level=logging.INFO, format="%(asctime)s %(name)s %(levelname)-8s %(message)s" ) @@ -70,203 +43,34 @@ # logger.setLevel(logging.DEBUG) # logging.getLogger('traccuracy.matchers._ctc').setLevel(logging.DEBUG) - - - -def get_cand_graph_from_segmentation( - segmentation, max_edge_distance, pos_labels=["y", "x"] -): - """_summary_ - - Args: - segmentation (np.array): A numpy array with shape (t, [z,], y, x) - """ - # add nodes - node_frame_dict = ( - {} - ) # construct a dictionary from time frame to node_id for efficiency - cand_graph = nx.DiGraph() - - for t in range(len(segmentation)): - nodes_in_frame = [] - props = regionprops(segmentation[t]) - for i, regionprop in enumerate(props): - node_id = f"{t}_{regionprop.label}" # TODO: previously node_id= f"{t}_{i}" - attrs = { - "t": t, - "segmentation_id": regionprop.label, - "area": regionprop.area, - } - centroid = regionprop.centroid # [z,] y, x - for label, value in zip(pos_labels, centroid): - attrs[label] = value - cand_graph.add_node(node_id, **attrs) - nodes_in_frame.append(node_id) - node_frame_dict[t] = nodes_in_frame - - print(f"Candidate nodes: {cand_graph.number_of_nodes()}") - - # add edges - frames = sorted(node_frame_dict.keys()) - for frame in tqdm(frames): - if frame + 1 not in node_frame_dict: - continue - next_nodes = node_frame_dict[frame + 1] - next_locs = [ - get_location(cand_graph.nodes[n], loc_keys=pos_labels) for n in next_nodes - ] - for node in node_frame_dict[frame]: - loc = get_location(cand_graph.nodes[node], loc_keys=pos_labels) - for next_id, next_loc in zip(next_nodes, next_locs): - dist = math.dist(next_loc, loc) - attrs = { - "dist": dist, - } - if dist < max_edge_distance: - cand_graph.add_edge(node, next_id, **attrs) - - print(f"Candidate edges: {cand_graph.number_of_edges()}") - return cand_graph - - -def get_location(node_data, loc_keys=("z", "y", "x")): - return [node_data[k] for k in loc_keys] - - -def solve_with_motile(cand_graph, widget): - motile_cand_graph = TrackGraph(cand_graph) - solver = Solver(motile_cand_graph) - - solver.add_constraints(MaxChildren(widget.get_max_children())) - solver.add_constraints(MaxParents(widget.get_max_parents())) - - if widget.get_distance_weight() is not None: - solver.add_costs(EdgeSelection(widget.get_distance_weight(), attribute="dist", constant=widget.get_distance_offset())) - if widget.get_appear_cost() is not None: - solver.add_costs(Appear(widget.get_appear_cost())) - - start_time = time.time() - solution = solver.solve() - print(f"Solution took {time.time() - start_time} seconds") - return solution, solver - - -def get_solution_nx_graph(solution, solver, cand_graph): - node_selected = solver.get_variables(NodeSelected) - edge_selected = solver.get_variables(EdgeSelected) - - selected_nodes = [ - node for node in cand_graph.nodes if solution[node_selected[node]] > 0.5 - ] - selected_edges = [ - edge for edge in cand_graph.edges if solution[edge_selected[edge]] > 0.5 - ] - - print(f"Selected nodes: {len(selected_nodes)}") - print(f"Selected edges: {len(selected_edges)}") - solution_graph = nx.edge_subgraph(cand_graph, selected_edges) - return solution_graph - - -def assign_tracklet_ids(graph): - """Add a tracklet_id attribute to a graph by removing division edges, - assigning one id to each connected component. - Designed as a helper for visualizing the graph in the napari Tracks layer. - - Args: - graph (nx.DiGraph): A networkx graph with a tracking solution - - Returns: - nx.DiGraph: The same graph with the tracklet_id assigned. Probably - occurrs in place but returned just to be clear. - """ - graph_copy = graph.copy() - - parents = [node for node, degree in graph.out_degree() if degree >= 2] - intertrack_edges = [] - - # Remove all intertrack edges from a copy of the original graph - for parent in parents: - daughters = [child for p, child in graph.out_edges(parent)] - for daughter in daughters: - graph_copy.remove_edge(parent, daughter) - intertrack_edges.append((parent, daughter)) - - track_id = 0 - for tracklet in nx.weakly_connected_components(graph_copy): - nx.set_node_attributes( - graph, {node: {"tracklet_id": track_id} for node in tracklet} - ) - track_id += 1 - return graph, intertrack_edges - - -def to_napari_tracks_layer( - graph, frame_key="t", location_keys=("y", "x"), properties=() -): - """Function to take a networkx graph and return the data needed to add to - a napari tracks layer. - - Args: - graph (nx.DiGraph): _description_ - frame_key (str, optional): Key in graph attributes containing time frame. - Defaults to "t". - location_keys (tuple, optional): Keys in graph node attributes containing - location. Should be in order: (Z), Y, X. Defaults to ("y", "x"). - properties (tuple, optional): Keys in graph node attributes to add - to the visualization layer. Defaults to (). NOTE: not working now :( - - Returns: - data : array (N, D+1) - Coordinates for N points in D+1 dimensions. ID,T,(Z),Y,X. The first - axis is the integer ID of the track. D is either 3 or 4 for planar - or volumetric timeseries respectively. - properties : dict {str: array (N,)} - Properties for each point. Each property should be an array of length N, - where N is the number of points. - graph : dict {int: list} - Graph representing associations between tracks. Dictionary defines the - mapping between a track ID and the parents of the track. This can be - one (the track has one parent, and the parent has >=1 child) in the - case of track splitting, or more than one (the track has multiple - parents, but only one child) in the case of track merging. - """ - napari_data = np.zeros((graph.number_of_nodes(), len(location_keys) + 2)) - napari_properties = {prop: np.zeros(graph.number_of_nodes()) for prop in properties} - napari_edges = {} - graph, intertrack_edges = assign_tracklet_ids(graph) - for index, node in enumerate(graph.nodes(data=True)): - node_id, data = node - location = [data[loc_key] for loc_key in location_keys] - napari_data[index] = [data["tracklet_id"], data[frame_key]] + location - for prop in properties: - if prop in data: - napari_properties[prop][index] = data[prop] - napari_edges = {} - for parent, child in intertrack_edges: - parent_track_id = graph.nodes[parent]["tracklet_id"] - child_track_id = graph.nodes[child]["tracklet_id"] - if child_track_id in napari_edges: - napari_edges[child_track_id].append(parent_track_id) - else: - napari_edges[child_track_id] = [parent_track_id] - return napari_data, napari_properties, napari_edges - - -class ExampleQWidget(QWidget): +class MotileWidget(QWidget): def __init__(self, viewer: "napari.viewer.Viewer"): super().__init__() self.viewer = viewer main_layout = QVBoxLayout() + # Select Labels layer + layer_group = QGroupBox("Select Input Layer") + layer_layout = QHBoxLayout() + self.layer_selection_box = QComboBox() + for layer in viewer.layers: + if isinstance(layer, Labels): + self.layer_selection_box.addItem(layer.name) + if len(self.layer_selection_box) == 0: + self.layer_selection_box.addItem("None") + layer_layout.addWidget(self.layer_selection_box) + layer_group.setLayout(layer_layout) + main_layout.addWidget(layer_group) + # Data-specific Hyperparameters section hyperparameters_group = QGroupBox("Data-specific Hyperparameters") hyperparameters_layout = QHBoxLayout() self.max_edge_distance_spinbox = QDoubleSpinBox() self.max_edge_distance_spinbox.setValue(50) + self.max_edge_distance_spinbox.setRange(1, 1e10) hyperparameters_layout.addWidget(QLabel("Max Edge Distance:")) hyperparameters_layout.addWidget(self.max_edge_distance_spinbox) hyperparameters_group.setLayout(hyperparameters_layout) @@ -292,9 +96,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"): constant_costs_layout = QHBoxLayout() self.appear_spinbox = QDoubleSpinBox() self.appear_spinbox.setValue(30) + self.appear_spinbox.setRange(0.0, 1e10) self.appear_checkbox = QCheckBox("Appear") self.appear_checkbox.setChecked(True) self.division_spinbox = QDoubleSpinBox() + self.division_spinbox.setRange(0.0, 1e10) self.division_checkbox = QCheckBox("Division") constant_costs_layout.addWidget(self.appear_checkbox) constant_costs_layout.addWidget(self.appear_spinbox) @@ -312,9 +118,10 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self.distance_checkbox = QCheckBox("Distance") self.distance_checkbox.setChecked(True) self.distance_weight_spinbox = QDoubleSpinBox() + self.distance_weight_spinbox.setRange(-1e10, 1e10) self.distance_weight_spinbox.setValue(1) self.distance_offset_spinbox = QDoubleSpinBox() - self.distance_offset_spinbox.setMinimum(-99.0) + self.distance_offset_spinbox.setRange(-1e10, 1e10) self.distance_offset_spinbox.setValue(-20) distance_layout.addWidget(self.distance_checkbox) distance_layout.addWidget(QLabel("Weight:")) @@ -324,7 +131,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"): feature_costs_layout.addLayout(distance_layout) # IOU row - iou_layout = QHBoxLayout() + """iou_layout = QHBoxLayout() self.iou_checkbox = QCheckBox("IOU") self.iou_weight_spinbox = QDoubleSpinBox() self.iou_offset_spinbox = QDoubleSpinBox() @@ -333,15 +140,24 @@ def __init__(self, viewer: "napari.viewer.Viewer"): iou_layout.addWidget(self.iou_weight_spinbox) iou_layout.addWidget(QLabel("Offset:")) iou_layout.addWidget(self.iou_offset_spinbox) - feature_costs_layout.addLayout(iou_layout) + feature_costs_layout.addLayout(iou_layout)""" feature_costs_group.setLayout(feature_costs_layout) main_layout.addWidget(feature_costs_group) + # Specify name text box + run_group = QGroupBox("Run") + run_layout = QHBoxLayout() + self.run_name = QLineEdit("tracks") + run_layout.addWidget(QLabel("Run name:")) + run_layout.addWidget(self.run_name) + # Generate Tracks button generate_tracks_btn = QPushButton("Generate Tracks") generate_tracks_btn.clicked.connect(self._on_click_generate_tracks) - main_layout.addWidget(generate_tracks_btn) + run_layout.addWidget(generate_tracks_btn) + run_group.setLayout(run_layout) + main_layout.addWidget(run_group) # Original layout elements # btn = QPushButton("Generate Graph") @@ -391,11 +207,21 @@ def get_iou_weight(self): def get_iou_offset(self): return self.iou_offset_spinbox.value() if self.iou_checkbox.isChecked() else None + def get_run_name(self): + return self.run_name.text() + + def get_labels_layer(self): + curr_text = self.layer_selection_box.currentText() + if curr_text == "None": + return None + return self.viewer.layers[curr_text] def _on_click_generate_tracks(self): # Logic for generating tracks - - segmentation = self.viewer.layers['Labeled Mask'].data + labels_layer = self.get_labels_layer() + if labels_layer is None: + return + segmentation = labels_layer.data print(f"Segmentation shape: {segmentation.shape}") cand_graph = get_cand_graph_from_segmentation(segmentation, self.get_max_edge_distance()) @@ -405,4 +231,4 @@ def _on_click_generate_tracks(self): solution_nx_graph = get_solution_nx_graph(solution, solver, cand_graph) track_data, track_props, track_edges = to_napari_tracks_layer(solution_nx_graph) - self.viewer.add_tracks(track_data, properties=track_props, graph=track_edges, name="Yay!") + self.viewer.add_tracks(track_data, properties=track_props, graph=track_edges, name=self.get_run_name())