Skip to content

Commit

Permalink
Update Widget organization and Div cost
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Jan 25, 2024
1 parent 5ad7310 commit 6886b27
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 232 deletions.
16 changes: 8 additions & 8 deletions scripts/test.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import napari
import zarr

from motile_plugin import ExampleQWidget
from motile_plugin import MotileWidget

# Load Zarr datasets
zarr_directory = "/Users/kharrington/git/cmalinmayor/motile-plugin/data/zarr_data.zarr"
zarr_directory = "/Volumes/funke$/lalitm/cellulus/experiments/data/science_meet/Fluo-N2DL-HeLa.zarr"
zarr_group = zarr.open_group(zarr_directory, mode='r')
image_stack = zarr_group['stack'][:]
labeled_mask = zarr_group['labeled_stack'][:]
labeled_mask = labeled_mask[0:5, :, :]
image_stack = zarr_group['test/raw'][:,0,:]
labeled_mask = zarr_group['post-processed-segmentation'][:,0,:]
labeled_mask = labeled_mask[:, :, :]

# Initialize Napari viewer
viewer = napari.Viewer()

# Add image and label layers to the viewer
# viewer.add_image(image_stack, name='Image Stack')
viewer.add_image(image_stack, name='Image Stack')
viewer.add_labels(labeled_mask, name='Labeled Mask')

# Add your custom widget
widget = ExampleQWidget(viewer)
widget = MotileWidget(viewer)
viewer.window.add_dock_widget(widget)

# Start the Napari GUI event loop
# napari.run()
napari.run()
4 changes: 2 additions & 2 deletions src/motile_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__version__ = "0.0.1"
from ._widget import ExampleQWidget
from ._widget import MotileWidget

__all__ = (
"ExampleQWidget",
"MotileWidget",
)
206 changes: 206 additions & 0 deletions src/motile_plugin/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import math
from pathlib import Path
import numpy as np

from motile import Solver, TrackGraph
from motile.constraints import MaxChildren, MaxParents
from motile.costs import EdgeSelection, Appear, Split
from motile.variables import NodeSelected, EdgeSelected
import networkx as nx
import toml
from tqdm import tqdm
import pprint
import time
from skimage.measure import regionprops
import tifffile
import logging

logging.basicConfig(
level=logging.INFO, format="%(asctime)s %(name)s %(levelname)-8s %(message)s"
)
logger = logging.getLogger(__name__)


def get_location(node_data, loc_keys=("z", "y", "x")):
return [node_data[k] for k in loc_keys]

def get_cand_graph_from_segmentation(
segmentation, max_edge_distance, pos_labels=["y", "x"]
):
"""_summary_
Args:
segmentation (np.array): A numpy array with shape (t, [z,], y, x)
"""
# add nodes
node_frame_dict = (
{}
) # construct a dictionary from time frame to node_id for efficiency
cand_graph = nx.DiGraph()

for t in range(len(segmentation)):
nodes_in_frame = []
props = regionprops(segmentation[t])
for i, regionprop in enumerate(props):
node_id = f"{t}_{regionprop.label}" # TODO: previously node_id= f"{t}_{i}"
attrs = {
"t": t,
"segmentation_id": regionprop.label,
"area": regionprop.area,
}
centroid = regionprop.centroid # [z,] y, x
for label, value in zip(pos_labels, centroid):
attrs[label] = value
cand_graph.add_node(node_id, **attrs)
nodes_in_frame.append(node_id)
node_frame_dict[t] = nodes_in_frame

print(f"Candidate nodes: {cand_graph.number_of_nodes()}")

# add edges
frames = sorted(node_frame_dict.keys())
for frame in tqdm(frames):
if frame + 1 not in node_frame_dict:
continue
next_nodes = node_frame_dict[frame + 1]
next_locs = [
get_location(cand_graph.nodes[n], loc_keys=pos_labels) for n in next_nodes
]
for node in node_frame_dict[frame]:
loc = get_location(cand_graph.nodes[node], loc_keys=pos_labels)
for next_id, next_loc in zip(next_nodes, next_locs):
dist = math.dist(next_loc, loc)
attrs = {
"dist": dist,
}
if dist < max_edge_distance:
cand_graph.add_edge(node, next_id, **attrs)

print(f"Candidate edges: {cand_graph.number_of_edges()}")
return cand_graph



def solve_with_motile(cand_graph, widget):
motile_cand_graph = TrackGraph(cand_graph)
solver = Solver(motile_cand_graph)

solver.add_constraints(MaxChildren(widget.get_max_children()))
solver.add_constraints(MaxParents(widget.get_max_parents()))

if widget.get_distance_weight() is not None:
solver.add_costs(EdgeSelection(widget.get_distance_weight(), attribute="dist", constant=widget.get_distance_offset()))
if widget.get_appear_cost() is not None:
solver.add_costs(Appear(widget.get_appear_cost()))
if widget.get_division_cost() is not None:
print(f"Adding division cost {widget.get_division_cost()}")
solver.add_costs(Split(constant=widget.get_division_cost()))

start_time = time.time()
solution = solver.solve()
print(f"Solution took {time.time() - start_time} seconds")
return solution, solver


def get_solution_nx_graph(solution, solver, cand_graph):
node_selected = solver.get_variables(NodeSelected)
edge_selected = solver.get_variables(EdgeSelected)

selected_nodes = [
node for node in cand_graph.nodes if solution[node_selected[node]] > 0.5
]
selected_edges = [
edge for edge in cand_graph.edges if solution[edge_selected[edge]] > 0.5
]

print(f"Selected nodes: {len(selected_nodes)}")
print(f"Selected edges: {len(selected_edges)}")
solution_graph = nx.edge_subgraph(cand_graph, selected_edges)
return solution_graph



def assign_tracklet_ids(graph):
"""Add a tracklet_id attribute to a graph by removing division edges,
assigning one id to each connected component.
Designed as a helper for visualizing the graph in the napari Tracks layer.
Args:
graph (nx.DiGraph): A networkx graph with a tracking solution
Returns:
nx.DiGraph: The same graph with the tracklet_id assigned. Probably
occurrs in place but returned just to be clear.
"""
graph_copy = graph.copy()

parents = [node for node, degree in graph.out_degree() if degree >= 2]
intertrack_edges = []

# Remove all intertrack edges from a copy of the original graph
for parent in parents:
daughters = [child for p, child in graph.out_edges(parent)]
for daughter in daughters:
graph_copy.remove_edge(parent, daughter)
intertrack_edges.append((parent, daughter))

track_id = 0
for tracklet in nx.weakly_connected_components(graph_copy):
nx.set_node_attributes(
graph, {node: {"tracklet_id": track_id} for node in tracklet}
)
track_id += 1
return graph, intertrack_edges


def to_napari_tracks_layer(
graph, frame_key="t", location_keys=("y", "x"), properties=()
):
"""Function to take a networkx graph and return the data needed to add to
a napari tracks layer.
Args:
graph (nx.DiGraph): _description_
frame_key (str, optional): Key in graph attributes containing time frame.
Defaults to "t".
location_keys (tuple, optional): Keys in graph node attributes containing
location. Should be in order: (Z), Y, X. Defaults to ("y", "x").
properties (tuple, optional): Keys in graph node attributes to add
to the visualization layer. Defaults to (). NOTE: not working now :(
Returns:
data : array (N, D+1)
Coordinates for N points in D+1 dimensions. ID,T,(Z),Y,X. The first
axis is the integer ID of the track. D is either 3 or 4 for planar
or volumetric timeseries respectively.
properties : dict {str: array (N,)}
Properties for each point. Each property should be an array of length N,
where N is the number of points.
graph : dict {int: list}
Graph representing associations between tracks. Dictionary defines the
mapping between a track ID and the parents of the track. This can be
one (the track has one parent, and the parent has >=1 child) in the
case of track splitting, or more than one (the track has multiple
parents, but only one child) in the case of track merging.
"""
napari_data = np.zeros((graph.number_of_nodes(), len(location_keys) + 2))
napari_properties = {prop: np.zeros(graph.number_of_nodes()) for prop in properties}
napari_edges = {}
graph, intertrack_edges = assign_tracklet_ids(graph)
for index, node in enumerate(graph.nodes(data=True)):
node_id, data = node
location = [data[loc_key] for loc_key in location_keys]
napari_data[index] = [data["tracklet_id"], data[frame_key]] + location
for prop in properties:
if prop in data:
napari_properties[prop][index] = data[prop]
napari_edges = {}
for parent, child in intertrack_edges:
parent_track_id = graph.nodes[parent]["tracklet_id"]
child_track_id = graph.nodes[child]["tracklet_id"]
if child_track_id in napari_edges:
napari_edges[child_track_id].append(parent_track_id)
else:
napari_edges[child_track_id] = [parent_track_id]
return napari_data, napari_properties, napari_edges

Loading

0 comments on commit 6886b27

Please sign in to comment.