Skip to content

Commit

Permalink
Merge pull request #65 from funkelab/53-add-lineage-tree-view
Browse files Browse the repository at this point in the history
53 add lineage tree view
  • Loading branch information
cmalinmayor authored Sep 11, 2024
2 parents 3a9a03d + ff21b54 commit e1f53d4
Show file tree
Hide file tree
Showing 35 changed files with 1,760 additions and 109 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ by framing the task as an Integer Linear Program (ILP).
See the motile [documentation](https://funkelab.github.io/motile)
for more details on the concepts and method.

Browsing tracking data with interactive lineage tree
![](docs/images/motile_napari_tree_view.gif)

----------------------------------

## Installation
Expand Down
Binary file added docs/images/motile_napari_tree_view.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion scripts/run_hela.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import napari
import zarr
from appdirs import AppDirs
from motile_plugin import MotileWidget
from motile_plugin.widgets import MotileWidget, TreeWidget
from napari.utils.theme import _themes

logging.basicConfig(
Expand Down Expand Up @@ -34,6 +34,8 @@
# Add your custom widget
widget = MotileWidget(viewer)
viewer.window.add_dock_widget(widget, name="Motile")
widget = TreeWidget(viewer)
viewer.window.add_dock_widget(widget, name="Lineage View", area="bottom")

# Start the Napari GUI event loop
napari.run()
6 changes: 1 addition & 5 deletions src/motile_plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from importlib.metadata import PackageNotFoundError, version

from .widgets.motile_widget import MotileWidget

try:
__version__ = version("motile-toolbox")
__version__ = version("motile-plugin")
except PackageNotFoundError:
# package is not installed
__version__ = "uninstalled"

__all__ = ("MotileWidget",)
2 changes: 2 additions & 0 deletions src/motile_plugin/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .motile_run import MotileRun # noqa
from .solver_params import SolverParams # noqa
47 changes: 28 additions & 19 deletions src/motile_plugin/backend/motile_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
from pydantic import BaseModel

from motile_plugin.core import Tracks

from .solver_params import SolverParams

STAMP_FORMAT = "%m%d%Y_%H%M%S"
Expand All @@ -26,14 +28,14 @@ class MotileRun(BaseModel):
"""

run_name: str
solver_params: SolverParams
solver_params: SolverParams | None = None
input_segmentation: np.ndarray | None = None
input_points: np.ndarray | None = None
output_segmentation: np.ndarray | None = None
tracks: nx.DiGraph | None = None
tracks: Tracks | None = None
time: datetime = datetime.now()
gaps: list[float] = []
status: str = "done"
scale: list[float] | None = None
# pydantic does not check numpy arrays
model_config = {"arbitrary_types_allowed": True}

Expand Down Expand Up @@ -70,13 +72,17 @@ def _unpack_id(_id: str) -> tuple[datetime, str]:
) from e
return time, run_name

def save(self, base_path: str | Path):
def save(self, base_path: str | Path) -> Path:
"""Save the run in the provided directory. Creates a subdirectory from
the timestamp and run name and stores one file for each element of the
run in that subdirectory.
Args:
base_path (str | Path): The directory to save the run in.
Returns:
(Path): The Path that the run was saved in. The last part of the
path is the directory that was created to store the run.
"""
base_path = Path(base_path)
run_dir = base_path / self._make_id()
Expand All @@ -86,13 +92,15 @@ def save(self, base_path: str | Path):
self._save_array(run_dir, IN_SEG_FILEANME, self.input_segmentation)
if self.input_points is not None:
self._save_array(run_dir, IN_POINTS_FILEANME, self.input_points)
if self.output_segmentation is not None:
self._save_array(
run_dir, OUT_SEG_FILEANME, self.output_segmentation
)
if self.tracks is not None:
self._save_tracks(run_dir)
if self.tracks.segmentation is not None:
self._save_array(
run_dir, OUT_SEG_FILEANME, self.tracks.segmentation
)
if self.tracks.graph is not None:
self._save_tracks_graph(run_dir, self.tracks.graph)
self._save_gaps(run_dir)
return run_dir

@classmethod
def load(cls, run_dir: Path | str, output_required: bool = True):
Expand All @@ -119,25 +127,23 @@ def load(cls, run_dir: Path | str, output_required: bool = True):
input_points = cls._load_array(
run_dir, IN_POINTS_FILEANME, required=False
)
if input_segmentation is None and input_points is None:
raise FileNotFoundError(
f"Must have either input segmentation or points: neither found in {run_dir}"
)
if output_required and input_segmentation is not None:
output_seg_required = True
else:
output_seg_required = False
output_segmentation = cls._load_array(
run_dir, OUT_SEG_FILEANME, required=output_seg_required
)
tracks = cls._load_tracks(run_dir, required=output_required)
tracks_graph = cls._load_tracks_graph(
run_dir, required=output_required
)
tracks = Tracks(graph=tracks_graph, segmentation=output_segmentation)
gaps = cls._load_gaps(run_dir)
return cls(
run_name=run_name,
solver_params=params,
input_segmentation=input_segmentation,
input_points=input_points,
output_segmentation=output_segmentation,
tracks=tracks,
time=time,
gaps=gaps,
Expand Down Expand Up @@ -218,7 +224,7 @@ def _load_array(
else:
return None

def _save_tracks(self, run_dir: Path):
def _save_tracks_graph(self, run_dir: Path, graph: nx.DiGraph):
"""Save the tracks to file. Currently uses networkx node link data
format (and saves it as json).
Expand All @@ -227,10 +233,10 @@ def _save_tracks(self, run_dir: Path):
"""
tracks_file = run_dir / TRACKS_FILENAME
with open(tracks_file, "w") as f:
json.dump(nx.node_link_data(self.tracks), f)
json.dump(nx.node_link_data(graph), f)

@staticmethod
def _load_tracks(
def _load_tracks_graph(
run_dir: Path, required: bool = True
) -> nx.DiGraph | None:
"""Load tracks from file. Currently uses networkx node link data
Expand Down Expand Up @@ -269,7 +275,10 @@ def _load_gaps(run_dir, required: bool = True) -> list[float]:
gaps_file = run_dir / GAPS_FILENAME
if gaps_file.is_file():
with open(gaps_file) as f:
gaps = list(map(float, f.read().split(",")))
file_content = f.read()
if file_content == "":
return []
gaps = list(map(float, file_content.split(",")))
return gaps
elif required:
raise FileNotFoundError(f"No gaps found at {gaps_file}")
Expand Down
3 changes: 2 additions & 1 deletion src/motile_plugin/backend/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_candidate_graph_from_points_list,
graph_to_nx,
)
from motile_toolbox.visualization.napari_utils import assign_tracklet_ids

from .solver_params import SolverParams

Expand Down Expand Up @@ -66,7 +67,7 @@ def solve(

solution_graph = solver.get_selected_subgraph(solution=solution)
solution_nx_graph = graph_to_nx(solution_graph)

solution_nx_graph, _ = assign_tracklet_ids(solution_nx_graph)
return solution_nx_graph


Expand Down
2 changes: 2 additions & 0 deletions src/motile_plugin/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .tracks import Tracks # noqa
from .node_type import NodeType # noqa
11 changes: 11 additions & 0 deletions src/motile_plugin/core/node_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from enum import Enum


class NodeType(Enum):
"""Types of nodes in the track graph. Currently used for standardizing
visualization. All nodes are exactly one type.
"""

SPLIT = "SPLIT"
END = "END"
CONTINUE = "CONTINUE"
81 changes: 81 additions & 0 deletions src/motile_plugin/core/tracks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from motile_toolbox.candidate_graph import NodeAttr
from pydantic import BaseModel

if TYPE_CHECKING:
from typing import Any

import networkx as nx
import numpy as np


class Tracks(BaseModel):
"""A set of tracks consisting of a graph and an optional segmentation.
The graph nodes represent detections and must have a time attribute and
position attribute. Edges in the graph represent links across time.
Attributes:
graph (nx.DiGraph): A graph with nodes representing detections and
and edges representing links across time. Assumed to be "valid"
tracks (e.g., this is not supposed to be a candidate graph),
but the structure is not verified.
segmentation (Optional(np.ndarray)): An optional segmentation that
accompanies the tracking graph. If a segmentation is provided,
it is assumed that the graph has an attribute (default
"seg_id") holding the segmentation id. Defaults to None.
time_attr (str): The attribute in the graph that specifies the time
frame each node is in.
pos_attr (str | tuple[str] | list[str]): The attribute in the graph
that specifies the position of each node. Can be a single attribute
that holds a list, or a list of attribute keys.
"""

graph: nx.DiGraph
segmentation: np.ndarray | None = None
time_attr: str = NodeAttr.TIME.value
pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value
scale: list[float] | None = None
# pydantic does not check numpy arrays
model_config = {"arbitrary_types_allowed": True}

def get_location(self, node: Any, incl_time: bool = False):
"""Get the location of a node in the graph. Optionally include the
time frame as the first dimension. Raises an error if the node
is not in the graph.
Args:
node (Any): The node id in the graph to get the location of.
incl_time (bool, optional): If true, include the time as the
first element of the location array. Defaults to False.
Returns:
list[float]: A list holding the location. If the position
is stored in a single key, the location could be any number
of dimensions.
"""
data = self.graph.nodes[node]
if isinstance(self.pos_attr, (tuple, list)):
pos = [data[dim] for dim in self.pos_attr]
else:
pos = data[self.pos_attr]

if incl_time:
pos = [data[self.time_attr], *pos]

return pos

def get_time(self, node: Any) -> int:
"""Get the time frame of a given node. Raises an error if the node
is not in the graph.
Args:
node (Any): The node id to get the time frame for
Returns:
int: The time frame that the node is in
"""
return self.graph.nodes[node][self.time_attr]
Empty file.
72 changes: 72 additions & 0 deletions src/motile_plugin/layers/track_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import copy

import napari
import networkx as nx
import numpy as np
from motile_toolbox.visualization import to_napari_tracks_layer
from napari.utils import CyclicLabelColormap

from motile_plugin.core import Tracks


class TrackGraph(napari.layers.Tracks):
"""Extended tracks layer that holds the track information and emits and responds
to dynamics visualization signals"""

def __init__(
self,
viewer: napari.Viewer,
tracks: Tracks,
name: str,
colormap: CyclicLabelColormap,
scale: tuple,
):
if tracks is None or tracks.graph is None:
graph = nx.DiGraph()
else:
graph = tracks.graph

track_data, track_props, track_edges = to_napari_tracks_layer(
graph, frame_key=tracks.time_attr, location_key=tracks.pos_attr
)

super().__init__(
data=track_data,
graph=track_edges,
properties=track_props,
name=name,
tail_length=3,
color_by="track_id",
scale=scale,
)

self.viewer = viewer
self.colormaps_dict["track_id"] = colormap

self.tracks_layer_graph = copy.deepcopy(
self.graph
) # for restoring graph later

def update_track_visibility(self, visible: list[int] | str) -> None:
"""Optionally show only the tracks of a current lineage"""

if visible == "all":
self.track_colors[:, 3] = 1
self.graph = self.tracks_layer_graph
else:
track_id_mask = np.isin(
self.properties["track_id"],
visible,
)
self.graph = {
key: self.tracks_layer_graph[key]
for key in visible
if key in self.tracks_layer_graph
}

self.track_colors[:, 3] = 0
self.track_colors[track_id_mask, 3] = 1
if len(self.graph.items()) == 0:
self.display_graph = False # empty dicts to not trigger update (bug?) so disable the graph entirely as a workaround
else:
self.display_graph = True
Loading

0 comments on commit e1f53d4

Please sign in to comment.