diff --git a/README.md b/README.md index 3dfc9ba..15c52d0 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,9 @@ by framing the task as an Integer Linear Program (ILP). See the motile [documentation](https://funkelab.github.io/motile) for more details on the concepts and method. +Browsing tracking data with interactive lineage tree +![](docs/images/motile_napari_tree_view.gif) + ---------------------------------- ## Installation diff --git a/docs/images/motile_napari_tree_view.gif b/docs/images/motile_napari_tree_view.gif new file mode 100644 index 0000000..e326e50 Binary files /dev/null and b/docs/images/motile_napari_tree_view.gif differ diff --git a/scripts/run_hela.py b/scripts/run_hela.py index 67a6f35..60aad19 100644 --- a/scripts/run_hela.py +++ b/scripts/run_hela.py @@ -4,7 +4,7 @@ import napari import zarr from appdirs import AppDirs -from motile_plugin import MotileWidget +from motile_plugin.widgets import MotileWidget, TreeWidget from napari.utils.theme import _themes logging.basicConfig( @@ -34,6 +34,8 @@ # Add your custom widget widget = MotileWidget(viewer) viewer.window.add_dock_widget(widget, name="Motile") +widget = TreeWidget(viewer) +viewer.window.add_dock_widget(widget, name="Lineage View", area="bottom") # Start the Napari GUI event loop napari.run() diff --git a/src/motile_plugin/__init__.py b/src/motile_plugin/__init__.py index 76817b5..a6b430c 100644 --- a/src/motile_plugin/__init__.py +++ b/src/motile_plugin/__init__.py @@ -1,11 +1,7 @@ from importlib.metadata import PackageNotFoundError, version -from .widgets.motile_widget import MotileWidget - try: - __version__ = version("motile-toolbox") + __version__ = version("motile-plugin") except PackageNotFoundError: # package is not installed __version__ = "uninstalled" - -__all__ = ("MotileWidget",) diff --git a/src/motile_plugin/backend/__init__.py b/src/motile_plugin/backend/__init__.py index e69de29..c51f27c 100644 --- a/src/motile_plugin/backend/__init__.py +++ b/src/motile_plugin/backend/__init__.py @@ -0,0 +1,2 @@ +from .motile_run import MotileRun # noqa +from .solver_params import SolverParams # noqa diff --git a/src/motile_plugin/backend/motile_run.py b/src/motile_plugin/backend/motile_run.py index 356c60a..4ee067a 100644 --- a/src/motile_plugin/backend/motile_run.py +++ b/src/motile_plugin/backend/motile_run.py @@ -6,6 +6,8 @@ import numpy as np from pydantic import BaseModel +from motile_plugin.core import Tracks + from .solver_params import SolverParams STAMP_FORMAT = "%m%d%Y_%H%M%S" @@ -26,14 +28,14 @@ class MotileRun(BaseModel): """ run_name: str - solver_params: SolverParams + solver_params: SolverParams | None = None input_segmentation: np.ndarray | None = None input_points: np.ndarray | None = None - output_segmentation: np.ndarray | None = None - tracks: nx.DiGraph | None = None + tracks: Tracks | None = None time: datetime = datetime.now() gaps: list[float] = [] status: str = "done" + scale: list[float] | None = None # pydantic does not check numpy arrays model_config = {"arbitrary_types_allowed": True} @@ -70,13 +72,17 @@ def _unpack_id(_id: str) -> tuple[datetime, str]: ) from e return time, run_name - def save(self, base_path: str | Path): + def save(self, base_path: str | Path) -> Path: """Save the run in the provided directory. Creates a subdirectory from the timestamp and run name and stores one file for each element of the run in that subdirectory. Args: base_path (str | Path): The directory to save the run in. + + Returns: + (Path): The Path that the run was saved in. The last part of the + path is the directory that was created to store the run. """ base_path = Path(base_path) run_dir = base_path / self._make_id() @@ -86,13 +92,15 @@ def save(self, base_path: str | Path): self._save_array(run_dir, IN_SEG_FILEANME, self.input_segmentation) if self.input_points is not None: self._save_array(run_dir, IN_POINTS_FILEANME, self.input_points) - if self.output_segmentation is not None: - self._save_array( - run_dir, OUT_SEG_FILEANME, self.output_segmentation - ) if self.tracks is not None: - self._save_tracks(run_dir) + if self.tracks.segmentation is not None: + self._save_array( + run_dir, OUT_SEG_FILEANME, self.tracks.segmentation + ) + if self.tracks.graph is not None: + self._save_tracks_graph(run_dir, self.tracks.graph) self._save_gaps(run_dir) + return run_dir @classmethod def load(cls, run_dir: Path | str, output_required: bool = True): @@ -119,10 +127,6 @@ def load(cls, run_dir: Path | str, output_required: bool = True): input_points = cls._load_array( run_dir, IN_POINTS_FILEANME, required=False ) - if input_segmentation is None and input_points is None: - raise FileNotFoundError( - f"Must have either input segmentation or points: neither found in {run_dir}" - ) if output_required and input_segmentation is not None: output_seg_required = True else: @@ -130,14 +134,16 @@ def load(cls, run_dir: Path | str, output_required: bool = True): output_segmentation = cls._load_array( run_dir, OUT_SEG_FILEANME, required=output_seg_required ) - tracks = cls._load_tracks(run_dir, required=output_required) + tracks_graph = cls._load_tracks_graph( + run_dir, required=output_required + ) + tracks = Tracks(graph=tracks_graph, segmentation=output_segmentation) gaps = cls._load_gaps(run_dir) return cls( run_name=run_name, solver_params=params, input_segmentation=input_segmentation, input_points=input_points, - output_segmentation=output_segmentation, tracks=tracks, time=time, gaps=gaps, @@ -218,7 +224,7 @@ def _load_array( else: return None - def _save_tracks(self, run_dir: Path): + def _save_tracks_graph(self, run_dir: Path, graph: nx.DiGraph): """Save the tracks to file. Currently uses networkx node link data format (and saves it as json). @@ -227,10 +233,10 @@ def _save_tracks(self, run_dir: Path): """ tracks_file = run_dir / TRACKS_FILENAME with open(tracks_file, "w") as f: - json.dump(nx.node_link_data(self.tracks), f) + json.dump(nx.node_link_data(graph), f) @staticmethod - def _load_tracks( + def _load_tracks_graph( run_dir: Path, required: bool = True ) -> nx.DiGraph | None: """Load tracks from file. Currently uses networkx node link data @@ -269,7 +275,10 @@ def _load_gaps(run_dir, required: bool = True) -> list[float]: gaps_file = run_dir / GAPS_FILENAME if gaps_file.is_file(): with open(gaps_file) as f: - gaps = list(map(float, f.read().split(","))) + file_content = f.read() + if file_content == "": + return [] + gaps = list(map(float, file_content.split(","))) return gaps elif required: raise FileNotFoundError(f"No gaps found at {gaps_file}") diff --git a/src/motile_plugin/backend/solve.py b/src/motile_plugin/backend/solve.py index cf5b586..8bdf9d7 100644 --- a/src/motile_plugin/backend/solve.py +++ b/src/motile_plugin/backend/solve.py @@ -14,6 +14,7 @@ get_candidate_graph_from_points_list, graph_to_nx, ) +from motile_toolbox.visualization.napari_utils import assign_tracklet_ids from .solver_params import SolverParams @@ -66,7 +67,7 @@ def solve( solution_graph = solver.get_selected_subgraph(solution=solution) solution_nx_graph = graph_to_nx(solution_graph) - + solution_nx_graph, _ = assign_tracklet_ids(solution_nx_graph) return solution_nx_graph diff --git a/src/motile_plugin/core/__init__.py b/src/motile_plugin/core/__init__.py new file mode 100644 index 0000000..35f4c70 --- /dev/null +++ b/src/motile_plugin/core/__init__.py @@ -0,0 +1,2 @@ +from .tracks import Tracks # noqa +from .node_type import NodeType # noqa diff --git a/src/motile_plugin/core/node_type.py b/src/motile_plugin/core/node_type.py new file mode 100644 index 0000000..4b5f0b2 --- /dev/null +++ b/src/motile_plugin/core/node_type.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class NodeType(Enum): + """Types of nodes in the track graph. Currently used for standardizing + visualization. All nodes are exactly one type. + """ + + SPLIT = "SPLIT" + END = "END" + CONTINUE = "CONTINUE" diff --git a/src/motile_plugin/core/tracks.py b/src/motile_plugin/core/tracks.py new file mode 100644 index 0000000..f5c55d9 --- /dev/null +++ b/src/motile_plugin/core/tracks.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from motile_toolbox.candidate_graph import NodeAttr +from pydantic import BaseModel + +if TYPE_CHECKING: + from typing import Any + + import networkx as nx + import numpy as np + + +class Tracks(BaseModel): + """A set of tracks consisting of a graph and an optional segmentation. + The graph nodes represent detections and must have a time attribute and + position attribute. Edges in the graph represent links across time. + + Attributes: + graph (nx.DiGraph): A graph with nodes representing detections and + and edges representing links across time. Assumed to be "valid" + tracks (e.g., this is not supposed to be a candidate graph), + but the structure is not verified. + segmentation (Optional(np.ndarray)): An optional segmentation that + accompanies the tracking graph. If a segmentation is provided, + it is assumed that the graph has an attribute (default + "seg_id") holding the segmentation id. Defaults to None. + time_attr (str): The attribute in the graph that specifies the time + frame each node is in. + pos_attr (str | tuple[str] | list[str]): The attribute in the graph + that specifies the position of each node. Can be a single attribute + that holds a list, or a list of attribute keys. + + """ + + graph: nx.DiGraph + segmentation: np.ndarray | None = None + time_attr: str = NodeAttr.TIME.value + pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value + scale: list[float] | None = None + # pydantic does not check numpy arrays + model_config = {"arbitrary_types_allowed": True} + + def get_location(self, node: Any, incl_time: bool = False): + """Get the location of a node in the graph. Optionally include the + time frame as the first dimension. Raises an error if the node + is not in the graph. + + Args: + node (Any): The node id in the graph to get the location of. + incl_time (bool, optional): If true, include the time as the + first element of the location array. Defaults to False. + + Returns: + list[float]: A list holding the location. If the position + is stored in a single key, the location could be any number + of dimensions. + """ + data = self.graph.nodes[node] + if isinstance(self.pos_attr, (tuple, list)): + pos = [data[dim] for dim in self.pos_attr] + else: + pos = data[self.pos_attr] + + if incl_time: + pos = [data[self.time_attr], *pos] + + return pos + + def get_time(self, node: Any) -> int: + """Get the time frame of a given node. Raises an error if the node + is not in the graph. + + Args: + node (Any): The node id to get the time frame for + + Returns: + int: The time frame that the node is in + """ + return self.graph.nodes[node][self.time_attr] diff --git a/src/motile_plugin/layers/__init__.py b/src/motile_plugin/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/motile_plugin/layers/track_graph.py b/src/motile_plugin/layers/track_graph.py new file mode 100644 index 0000000..d8410d6 --- /dev/null +++ b/src/motile_plugin/layers/track_graph.py @@ -0,0 +1,72 @@ +import copy + +import napari +import networkx as nx +import numpy as np +from motile_toolbox.visualization import to_napari_tracks_layer +from napari.utils import CyclicLabelColormap + +from motile_plugin.core import Tracks + + +class TrackGraph(napari.layers.Tracks): + """Extended tracks layer that holds the track information and emits and responds + to dynamics visualization signals""" + + def __init__( + self, + viewer: napari.Viewer, + tracks: Tracks, + name: str, + colormap: CyclicLabelColormap, + scale: tuple, + ): + if tracks is None or tracks.graph is None: + graph = nx.DiGraph() + else: + graph = tracks.graph + + track_data, track_props, track_edges = to_napari_tracks_layer( + graph, frame_key=tracks.time_attr, location_key=tracks.pos_attr + ) + + super().__init__( + data=track_data, + graph=track_edges, + properties=track_props, + name=name, + tail_length=3, + color_by="track_id", + scale=scale, + ) + + self.viewer = viewer + self.colormaps_dict["track_id"] = colormap + + self.tracks_layer_graph = copy.deepcopy( + self.graph + ) # for restoring graph later + + def update_track_visibility(self, visible: list[int] | str) -> None: + """Optionally show only the tracks of a current lineage""" + + if visible == "all": + self.track_colors[:, 3] = 1 + self.graph = self.tracks_layer_graph + else: + track_id_mask = np.isin( + self.properties["track_id"], + visible, + ) + self.graph = { + key: self.tracks_layer_graph[key] + for key in visible + if key in self.tracks_layer_graph + } + + self.track_colors[:, 3] = 0 + self.track_colors[track_id_mask, 3] = 1 + if len(self.graph.items()) == 0: + self.display_graph = False # empty dicts to not trigger update (bug?) so disable the graph entirely as a workaround + else: + self.display_graph = True diff --git a/src/motile_plugin/layers/track_labels.py b/src/motile_plugin/layers/track_labels.py new file mode 100644 index 0000000..7b002bc --- /dev/null +++ b/src/motile_plugin/layers/track_labels.py @@ -0,0 +1,128 @@ +import copy +from typing import Dict, List + +import napari +import numpy as np +from napari.utils import CyclicLabelColormap, DirectLabelColormap + +from motile_plugin.core import Tracks + +from ..utils.node_selection import NodeSelectionList + + +def create_selection_label_cmap( + color_dict_rgb: Dict, visible: List[int] | str, highlighted: List[int] +) -> DirectLabelColormap: + """Generates a label colormap with three possible opacity values (0 for invisibible labels, 0.6 for visible labels, and 1 for selected labels)""" + + color_dict_rgb_temp = copy.deepcopy(color_dict_rgb) + if visible == "all": + for key in color_dict_rgb_temp: + if key is not None: + color_dict_rgb_temp[key][-1] = 0.6 # set opacity to 0.6 + else: + for label in visible: + color_dict_rgb_temp[label][-1] = 0.6 # set opacity to 0.6 + + for label in highlighted: + if label != 0: + color_dict_rgb_temp[label][-1] = 1 # set opacity to full + + return DirectLabelColormap(color_dict=color_dict_rgb_temp) + + +class TrackLabels(napari.layers.Labels): + """Extended labels layer that holds the track information and emits and responds to dynamics visualization signals""" + + def __init__( + self, + viewer: napari.Viewer, + data: np.array, + name: str, + colormap: CyclicLabelColormap, + tracks: Tracks, + opacity: float, + selected_nodes: NodeSelectionList, + scale: tuple, + ): + self.nodes = list(tracks.graph.nodes) + props = { + "node_id": self.nodes, + "track_id": [ + data["tracklet_id"] + for _, data in tracks.graph.nodes(data=True) + ], + "t": [tracks.get_time(node) for node in self.nodes], + } + super().__init__( + data=data, + name=name, + opacity=opacity, + colormap=colormap, + properties=props, + scale=scale, + ) + + self.viewer = viewer + self.selected_nodes = selected_nodes + self.tracks = tracks + + self.base_label_color_dict = self.create_label_color_dict( + np.unique(self.properties["track_id"]), colormap=colormap + ) + + @self.mouse_drag_callbacks.append + def click(_, event): + if event.type == "mouse_press": + label = self.get_value( + event.position, + view_direction=event.view_direction, + dims_displayed=event.dims_displayed, + world=True, + ) + + if label is not None and label != 0: + t_values = self.properties["t"] + track_ids = self.properties["track_id"] + index = np.where( + (t_values == event.position[0]) & (track_ids == label) + )[ + 0 + ] # np.where returns a tuple with an array per dimension, here we apply it to a single dimension so take the first element (an array of indices fulfilling condition) + node_id = self.nodes[index[0]] + append = "Shift" in event.modifiers + self.selected_nodes.add(node_id, append) + + def create_label_color_dict( + self, labels: List[int], colormap: CyclicLabelColormap + ) -> Dict: + """Extract the label colors to generate a base colormap, but keep opacity at 0""" + + color_dict_rgb = {None: [0.0, 0.0, 0.0, 0.0]} + + # Iterate over unique labels + for label in labels: + color = colormap.map(label) + color[-1] = ( + 0 # Set opacity to 0 (will be replaced when a label is visible/invisible/selected) + ) + color_dict_rgb[label] = color + + return color_dict_rgb + + def update_label_colormap(self, visible: list[int] | str) -> None: + """Updates the opacity of the label colormap to highlight the selected label and optionally hide cells not belonging to the current lineage""" + + highlighted = [ + self.tracks.graph.nodes[node]["tracklet_id"] + for node in self.selected_nodes + if self.tracks.get_time(node) == self.viewer.dims.current_step[0] + ] + + if self.base_label_color_dict is not None: + colormap = create_selection_label_cmap( + self.base_label_color_dict, + visible=visible, + highlighted=highlighted, + ) + self.colormap = colormap diff --git a/src/motile_plugin/layers/track_points.py b/src/motile_plugin/layers/track_points.py new file mode 100644 index 0000000..5476b94 --- /dev/null +++ b/src/motile_plugin/layers/track_points.py @@ -0,0 +1,121 @@ +import napari +import numpy as np + +from motile_plugin.core import NodeType, Tracks + +from ..utils.node_selection import NodeSelectionList + + +class TrackPoints(napari.layers.Points): + """Extended points layer that holds the track information and emits and responds to dynamics visualization signals""" + + def __init__( + self, + viewer: napari.Viewer, + tracks: Tracks, + name: str, + selected_nodes: NodeSelectionList, + symbolmap: dict[NodeType, str], + colormap: napari.utils.Colormap, + scale: tuple, + ): + self.colormap = colormap + self.symbolmap = symbolmap + + self.nodes = list(tracks.graph.nodes) + self.node_index_dict = dict( + zip(self.nodes, [self.nodes.index(node) for node in self.nodes]) + ) + points = [ + tracks.get_location(node, incl_time=True) for node in self.nodes + ] + track_ids = [ + tracks.graph.nodes[node]["tracklet_id"] for node in self.nodes + ] + colors = [colormap.map(track_id) for track_id in track_ids] + symbols = self.get_symbols(tracks, symbolmap) + + super().__init__( + data=points, + name=name, + symbol=symbols, + face_color=colors, + size=5, + properties={"node_id": self.nodes, "track_id": track_ids}, + border_color=[1, 1, 1, 1], + scale=scale, + ) + + self.viewer = viewer + self.selected_nodes = selected_nodes + + @self.mouse_drag_callbacks.append + def click(layer, event): + if event.type == "mouse_press": + # is the value passed from the click event? + point_index = layer.get_value( + event.position, + view_direction=event.view_direction, + dims_displayed=event.dims_displayed, + world=True, + ) + if point_index is not None: + node_id = self.nodes[point_index] + append = "Shift" in event.modifiers + self.selected_nodes.add(node_id, append) + + # listen to updates in the selected data (from the point selection tool) to update the nodes in self.selected_nodes + self.selected_data.events.items_changed.connect(self._update_selection) + + def _update_selection(self): + """Replaces the list of selected_nodes with the selection provided by the user""" + + selected_points = self.selected_data + self.selected_nodes.reset() + for point in selected_points: + node_id = self.nodes[point] + self.selected_nodes.add(node_id, True) + + def get_symbols( + self, tracks: Tracks, symbolmap: dict[NodeType, str] + ) -> list[str]: + statemap = { + 0: NodeType.END, + 1: NodeType.CONTINUE, + 2: NodeType.SPLIT, + } + symbols = [ + symbolmap[statemap[degree]] + for _, degree in tracks.graph.out_degree + ] + return symbols + + def update_point_outline(self, visible: list[int] | str) -> None: + """Update the outline color of the selected points and visibility according to display mode + + Args: + visible (list[int] | str): A list of track ids, or "all" + """ + # filter out the non-selected tracks if in lineage mode + if visible == "all": + self.shown[:] = True + else: + indices = np.where(np.isin(self.properties["track_id"], visible))[ + 0 + ].tolist() + self.shown[:] = False + self.shown[indices] = True + + # set border color for selected item + self.border_color = [1, 1, 1, 1] + self.size = 5 + for node in self.selected_nodes: + index = self.node_index_dict[node] + self.border_color[index] = ( + 0, + 1, + 1, + 1, + ) + self.size[index] = 7 + self.refresh() diff --git a/src/motile_plugin/napari.yaml b/src/motile_plugin/napari.yaml index 66f6acd..eed51f5 100644 --- a/src/motile_plugin/napari.yaml +++ b/src/motile_plugin/napari.yaml @@ -7,9 +7,13 @@ categories: ["Utilities"] contributions: commands: - id: motile-plugin.motile_widget - python_name: motile_plugin.widgets.motile_widget:MotileWidget + python_name: motile_plugin.widgets.motile.motile_widget:MotileWidget title: "Start the motile widget" short_title: "motile widget" + - id: motile-plugin.tree_widget + python_name: motile_plugin.widgets.tracks_view.tree_widget:TreeWidget + title: "Open the lineage view widget" + short_title: "lineage view" - id: motile-plugin.solve python_name: motile_plugin.backend.solve:solve title: "Run motile tracking (backend only)" @@ -19,6 +23,8 @@ contributions: widgets: - command: motile-plugin.motile_widget display_name: Motile Tracking + - command: motile-plugin.tree_widget + display_name: Lineage View sample_data: - command: motile-plugin.Fluo_N2DL_HeLa key: "Fluo-N2DL-HeLa" diff --git a/src/motile_plugin/utils/node_selection.py b/src/motile_plugin/utils/node_selection.py new file mode 100644 index 0000000..a098d25 --- /dev/null +++ b/src/motile_plugin/utils/node_selection.py @@ -0,0 +1,48 @@ +from psygnal import Signal +from PyQt5.QtCore import QObject + + +class NodeSelectionList(QObject): + """Updates the current selection (0, 1, or 2) of nodes. Sends a signal on every update. + Stores a list of node ids only.""" + + list_updated = Signal() + + def __init__(self): + super().__init__() + self._list = [] + + def add(self, item, append: bool | None = False): + """Append or replace an item to the list, depending on the number of items present and the keyboard modifiers used. Emit update signal""" + + # first check if this node was already present, if so, remove it. + if item in self._list: + self._list.remove(item) + + # single selection plus shift modifier: append to list to have two items in it + elif append: + self._list.append(item) + + # replace item in list + else: + self._list = [] + self._list.append(item) + + # emit update signal + self.list_updated.emit() + + def flip(self): + """Change the order of the items in the list""" + if len(self) == 2: + self._list = [self._list[1], self._list[0]] + + def reset(self): + """Empty list and emit update signal""" + self._list = [] + self.list_updated.emit() + + def __getitem__(self, index): + return self._list[index] + + def __len__(self): + return len(self._list) diff --git a/src/motile_plugin/utils/tree_widget_utils.py b/src/motile_plugin/utils/tree_widget_utils.py new file mode 100644 index 0000000..a11d9f9 --- /dev/null +++ b/src/motile_plugin/utils/tree_widget_utils.py @@ -0,0 +1,171 @@ +from typing import Dict, List + +import napari.layers +import networkx as nx +import pandas as pd +from motile_plugin.core import NodeType, Tracks + + +def extract_sorted_tracks( + tracks: Tracks, + colormap: napari.utils.CyclicLabelColormap, +) -> pd.DataFrame | None: + """ + Extract the information of individual tracks required for constructing the pyqtgraph plot. Follows the same logic as the relabel_segmentation + function from the Motile toolbox. + + Args: + tracks (motile_plugin.core.Tracks): A tracks object containing a graph + to be converted into a dataframe. + colormap (napari.utils.CyclicLabelColormap): The colormap to use to + extract the color of each node from the track ID + + Returns: + pd.DataFrame | None: data frame with all the information needed to + construct the pyqtgraph plot. Columns are: 't', 'node_id', 'track_id', + 'color', 'x', 'y', ('z'), 'index', 'parent_id', 'parent_track_id', + 'state', 'symbol', and 'x_axis_pos' + """ + if tracks is None or tracks.graph is None: + return None + + solution_nx_graph = tracks.graph + + track_list = [] + id_counter = 1 + parent_mapping = [] + + # Identify parent nodes (nodes with more than one child) + parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1] + end_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d == 0] + + # Make a copy of the graph and remove outgoing edges from parent nodes to isolate tracks + soln_copy = solution_nx_graph.copy() + for parent_node in parent_nodes: + out_edges = solution_nx_graph.out_edges(parent_node) + soln_copy.remove_edges_from(out_edges) + + # Process each weakly connected component as a separate track + for node_set in nx.weakly_connected_components(soln_copy): + # Sort nodes in each weakly connected component by their time attribute to ensure correct order + sorted_nodes = sorted( + node_set, + key=lambda node: tracks.get_time(node), + ) + + parent_track_id = None + for node in sorted_nodes: + pos = tracks.get_location(node) + if node in parent_nodes: + state = NodeType.SPLIT + symbol = "t1" + elif node in end_nodes: + state = NodeType.END + symbol = "x" + else: + state = NodeType.CONTINUE + symbol = "o" + + track_id = solution_nx_graph.nodes[node]["tracklet_id"] + track_dict = { + "t": tracks.get_time(node), + "node_id": node, + "track_id": track_id, + "color": colormap.map(track_id) * 255, + "x": pos[-1], + "y": pos[-2], + "parent_id": 0, + "parent_track_id": 0, + "state": state, + "symbol": symbol, + } + + if len(pos) == 3: + track_dict["z"] = pos[0] + + # Determine parent_id and parent_track_id + predecessors = list(solution_nx_graph.predecessors(node)) + if predecessors: + parent_id = predecessors[ + 0 + ] # There should be only one predecessor in a lineage tree + track_dict["parent_id"] = parent_id + + if parent_track_id is None: + parent_track_id = solution_nx_graph.nodes[parent_id][ + "tracklet_id" + ] + track_dict["parent_track_id"] = parent_track_id + + else: + parent_track_id = 0 + track_dict["parent_id"] = 0 + track_dict["parent_track_id"] = parent_track_id + + track_list.append(track_dict) + + parent_mapping.append( + {"track_id": id_counter, "parent_track_id": parent_track_id} + ) + id_counter += 1 + + x_axis_order = sort_track_ids(parent_mapping) + + for node in track_list: + node["x_axis_pos"] = x_axis_order.index(node["track_id"]) + + return pd.DataFrame(track_list) + + +def sort_track_ids(track_list: List[Dict]) -> List[Dict]: + """ + Sort track IDs such to maintain left-first order in the tree formed by parent-child relationships. + Used to determine the x-axis order of the tree plot. + + Args: + track_list (list): List of dictionaries with 'track_id' and 'parent_track_id'. + + Returns: + list: Ordered list of track IDs for the x-axis. + """ + + roots = [ + node["track_id"] for node in track_list if node["parent_track_id"] == 0 + ] + x_axis_order = list(roots) + + # Find the children of each of the starting points, and work down the tree. + while len(roots) > 0: + children_list = [] + for track_id in roots: + children = [ + node["track_id"] + for node in track_list + if node["parent_track_id"] == track_id + ] + for i, child in enumerate(children): + [children_list.append(child)] + x_axis_order.insert(x_axis_order.index(track_id) + i, child) + roots = children_list + + return x_axis_order + + +def extract_lineage_tree(graph: nx.DiGraph, node_id: str) -> List[str]: + """Extract the entire lineage tree including horizontal relations for a given node""" + + # go up the tree to identify the root node + root_node = node_id + while True: + predecessors = list(graph.predecessors(root_node)) + if not predecessors: + break + root_node = predecessors[0] + + # extract all descendants to get the full tree + nodes = nx.descendants(graph, root_node) + + # include root + nodes.add(root_node) + + return list(nodes) diff --git a/src/motile_plugin/widgets/__init__.py b/src/motile_plugin/widgets/__init__.py index e69de29..a102381 100644 --- a/src/motile_plugin/widgets/__init__.py +++ b/src/motile_plugin/widgets/__init__.py @@ -0,0 +1,3 @@ +from .motile.motile_widget import MotileWidget # noqa +from .tracks_view.tree_widget import TreeWidget # noqa +from .tracks_view.tracks_viewer import TracksViewer # noqa diff --git a/src/motile_plugin/widgets/motile/__init__.py b/src/motile_plugin/widgets/motile/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/motile_plugin/widgets/motile_widget.py b/src/motile_plugin/widgets/motile/motile_widget.py similarity index 65% rename from src/motile_plugin/widgets/motile_widget.py rename to src/motile_plugin/widgets/motile/motile_widget.py index 5ad029e..e359c24 100644 --- a/src/motile_plugin/widgets/motile_widget.py +++ b/src/motile_plugin/widgets/motile/motile_widget.py @@ -1,12 +1,13 @@ import logging -from motile_toolbox.utils import relabel_segmentation -from motile_toolbox.visualization import to_napari_tracks_layer +import networkx as nx +import numpy as np +from motile_toolbox.candidate_graph import NodeAttr from napari import Viewer -from napari.layers import Labels, Tracks -from qtpy.QtCore import Signal +from psygnal import Signal from qtpy.QtWidgets import ( QLabel, + QScrollArea, QVBoxLayout, QWidget, ) @@ -14,6 +15,8 @@ from motile_plugin.backend.motile_run import MotileRun from motile_plugin.backend.solve import solve +from motile_plugin.core import Tracks +from motile_plugin.widgets.tracks_view.tracks_viewer import TracksViewer from .run_editor import RunEditor from .run_viewer import RunViewer @@ -22,22 +25,24 @@ logger = logging.getLogger(__name__) -class MotileWidget(QWidget): - """The main widget for the motile napari plugin. Coordinates sub-widgets - and calls the back-end motile solver. +class MotileWidget(QScrollArea): + """A widget that controls the backend components of the motile napari plugin. + Recieves user input about solver parameters, runs motile, and passes + results to the TrackingViewController. """ # A signal for passing events from the motile solver to the run view widget # To provide updates on progress of the solver solver_update = Signal() + view_tracks = Signal(Tracks, str) + remove_layers = Signal() def __init__(self, viewer: Viewer): super().__init__() self.viewer: Viewer = viewer - - # Declare napari layers for displaying outputs (managed by the widget) - self.output_seg_layer: Labels | None = None - self.tracks_layer: Tracks | None = None + tracks_viewer = TracksViewer.get_instance(self.viewer) + self.view_tracks.connect(tracks_viewer.update_tracks) + self.remove_layers.connect(tracks_viewer.remove_napari_layers) # Create sub-widgets and connect signals self.edit_run_widget = RunEditor(self.viewer) @@ -49,7 +54,7 @@ def __init__(self, viewer: Viewer): self.solver_update.connect(self.view_run_widget.solver_event_update) self.run_list_widget = RunsList() - self.run_list_widget.view_run.connect(self.view_run_napari) + self.run_list_widget.view_run.connect(self.view_run) # Create main layout main_layout = QVBoxLayout() @@ -57,53 +62,12 @@ def __init__(self, viewer: Viewer): main_layout.addWidget(self.view_run_widget) main_layout.addWidget(self.edit_run_widget) main_layout.addWidget(self.run_list_widget) - self.setLayout(main_layout) - - def remove_napari_layers(self) -> None: - """Remove the currently stored layers from the napari viewer, if present""" - if ( - self.output_seg_layer - and self.output_seg_layer in self.viewer.layers - ): - self.viewer.layers.remove(self.output_seg_layer) - if self.tracks_layer and self.tracks_layer in self.viewer.layers: - self.viewer.layers.remove(self.tracks_layer) - - def update_napari_layers(self, run: MotileRun) -> None: - """Remove the old napari layers and update them according to the run output. - Will create new segmentation and tracks layers and add them to the viewer. - - Args: - run (MotileRun): The run outputs to visualize in napari. - """ - # Remove old layers if necessary - self.remove_napari_layers() - - # Create new layers - if run.output_segmentation is not None: - self.output_seg_layer = Labels( - run.output_segmentation[:, 0], name=run.run_name + "_seg" - ) - self.viewer.add_layer(self.output_seg_layer) - else: - self.output_seg_layer = None - - if run.tracks is None or run.tracks.number_of_nodes() == 0: - self.tracks_layer = None - else: - track_data, track_props, track_edges = to_napari_tracks_layer( - run.tracks - ) - self.tracks_layer = Tracks( - track_data, - properties=track_props, - graph=track_edges, - name=run.run_name + "_tracks", - tail_length=3, - ) - self.viewer.add_layer(self.tracks_layer) + main_widget = QWidget() + main_widget.setLayout(main_layout) + self.setWidget(main_widget) + self.setWidgetResizable(True) - def view_run_napari(self, run: MotileRun) -> None: + def view_run(self, run: MotileRun) -> None: """Populates the run viewer and the napari layers with the output of the provided run. @@ -113,7 +77,7 @@ def view_run_napari(self, run: MotileRun) -> None: self.view_run_widget.update_run(run) self.edit_run_widget.hide() self.view_run_widget.show() - self.update_napari_layers(run) + self.view_tracks.emit(run.tracks, run.run_name) def edit_run(self, run: MotileRun | None): """Create or edit a new run in the run editor. Also removes solution layers @@ -127,7 +91,7 @@ def edit_run(self, run: MotileRun | None): if run: self.edit_run_widget.new_run(run) self.run_list_widget.runs_list.clearSelection() - self.remove_napari_layers() + self.remove_layers.emit() def _generate_tracks(self, run: MotileRun) -> None: """Called when we start solving a new run. Switches from run editor to run viewer @@ -142,6 +106,47 @@ def _generate_tracks(self, run: MotileRun) -> None: worker.returned.connect(self._on_solve_complete) worker.start() + def relabel_segmentation( + self, + solution_nx_graph: nx.DiGraph, + segmentation: np.ndarray, + ) -> np.ndarray: + """Relabel a segmentation based on tracking results so that nodes in same + track share the same id. IDs do change at division. + + Args: + solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use + for relabeling. Nodes not in graph will be removed from seg. Original + segmentation ids and hypothesis ids have to be stored in the graph so we + can map them back. + segmentation (np.ndarray): Original (potentially multi-hypothesis) + segmentation with dimensions (t,h,[z],y,x), where h is 1 for single + input segmentation. + + Returns: + np.ndarray: Relabeled segmentation array where nodes in same track share same + id with shape (t,1,[z],y,x) + """ + output_shape = (segmentation.shape[0], 1, *segmentation.shape[2:]) + tracked_masks = np.zeros_like(segmentation, shape=output_shape) + for node, _data in solution_nx_graph.nodes(data=True): + time_frame = solution_nx_graph.nodes[node][NodeAttr.TIME.value] + previous_seg_id = solution_nx_graph.nodes[node][ + NodeAttr.SEG_ID.value + ] + tracklet_id = solution_nx_graph.nodes[node]["tracklet_id"] + if NodeAttr.SEG_HYPO.value in solution_nx_graph.nodes[node]: + hypothesis_id = solution_nx_graph.nodes[node][ + NodeAttr.SEG_HYPO.value + ] + else: + hypothesis_id = 0 + previous_seg_mask = ( + segmentation[time_frame, hypothesis_id] == previous_seg_id + ) + tracked_masks[time_frame, 0][previous_seg_mask] = tracklet_id + return tracked_masks + @thread_worker def solve_with_motile(self, run: MotileRun) -> MotileRun: """Runs the solver and relabels the segmentation to match @@ -163,15 +168,20 @@ def solve_with_motile(self, run: MotileRun) -> MotileRun: input_data = run.input_points else: raise ValueError("Must have one of input segmentation or points") - run.tracks = solve( + graph = solve( run.solver_params, input_data, lambda event_data: self._on_solver_event(run, event_data), ) if run.input_segmentation is not None: - run.output_segmentation = relabel_segmentation( - run.tracks, run.input_segmentation + output_segmentation = self.relabel_segmentation( + graph, run.input_segmentation ) + else: + output_segmentation = None + run.tracks = Tracks( + graph=graph, segmentation=output_segmentation, scale=run.scale + ) return run def _on_solver_event(self, run: MotileRun, event_data: dict) -> None: @@ -210,7 +220,7 @@ def _on_solve_complete(self, run: MotileRun) -> None: """ run.status = "done" self.solver_update.emit() - self.view_run_napari(run) + self.view_run(run) def _title_widget(self) -> QWidget: """Create the title and intro paragraph widget, with links to docs diff --git a/src/motile_plugin/widgets/param_values.py b/src/motile_plugin/widgets/motile/param_values.py similarity index 100% rename from src/motile_plugin/widgets/param_values.py rename to src/motile_plugin/widgets/motile/param_values.py diff --git a/src/motile_plugin/widgets/params_editor.py b/src/motile_plugin/widgets/motile/params_editor.py similarity index 100% rename from src/motile_plugin/widgets/params_editor.py rename to src/motile_plugin/widgets/motile/params_editor.py diff --git a/src/motile_plugin/widgets/params_viewer.py b/src/motile_plugin/widgets/motile/params_viewer.py similarity index 100% rename from src/motile_plugin/widgets/params_viewer.py rename to src/motile_plugin/widgets/motile/params_viewer.py diff --git a/src/motile_plugin/widgets/run_editor.py b/src/motile_plugin/widgets/motile/run_editor.py similarity index 99% rename from src/motile_plugin/widgets/run_editor.py rename to src/motile_plugin/widgets/motile/run_editor.py index 8ba302c..93c1054 100644 --- a/src/motile_plugin/widgets/run_editor.py +++ b/src/motile_plugin/widgets/motile/run_editor.py @@ -200,6 +200,7 @@ def get_run(self) -> MotileRun | None: input_segmentation=input_seg, input_points=input_points, time=datetime.now(), + scale=input_layer.scale, ) def emit_run(self) -> None: diff --git a/src/motile_plugin/widgets/run_viewer.py b/src/motile_plugin/widgets/motile/run_viewer.py similarity index 100% rename from src/motile_plugin/widgets/run_viewer.py rename to src/motile_plugin/widgets/motile/run_viewer.py diff --git a/src/motile_plugin/widgets/runs_list.py b/src/motile_plugin/widgets/motile/runs_list.py similarity index 100% rename from src/motile_plugin/widgets/runs_list.py rename to src/motile_plugin/widgets/motile/runs_list.py diff --git a/src/motile_plugin/widgets/tracks_view/__init__.py b/src/motile_plugin/widgets/tracks_view/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/motile_plugin/widgets/tracks_view/navigation_widget.py b/src/motile_plugin/widgets/tracks_view/navigation_widget.py new file mode 100644 index 0000000..fe2f44b --- /dev/null +++ b/src/motile_plugin/widgets/tracks_view/navigation_widget.py @@ -0,0 +1,173 @@ +import pandas as pd +from qtpy.QtWidgets import ( + QGroupBox, + QHBoxLayout, + QPushButton, + QWidget, +) + +from motile_plugin.utils.node_selection import NodeSelectionList + + +class NavigationWidget(QWidget): + def __init__( + self, + track_df: pd.DataFrame, + lineage_df: pd.DataFrame, + view_direction: str, + selected_nodes: NodeSelectionList, + ): + """Widget for controlling navigation in the tree widget + + Args: + track_df (pd.DataFrame): The dataframe holding the track information + view_direction (str): The view direction of the tree widget. Options: "vertical", "horizontal". + selected_nodes (NodeSelectionList): The list of selected nodes. + """ + + super().__init__() + self.track_df = track_df + self.lineage_df = lineage_df + self.view_direction = view_direction + self.selected_nodes = selected_nodes + + navigation_box = QGroupBox("Navigation [\u2b05 \u27a1 \u2b06 \u2b07]") + navigation_layout = QHBoxLayout() + left_button = QPushButton("\u2b05") + right_button = QPushButton("\u27a1") + up_button = QPushButton("\u2b06") + down_button = QPushButton("\u2b07") + + left_button.clicked.connect(lambda: self.move("left")) + right_button.clicked.connect(lambda: self.move("right")) + up_button.clicked.connect(lambda: self.move("up")) + down_button.clicked.connect(lambda: self.move("down")) + + navigation_layout.addWidget(left_button) + navigation_layout.addWidget(right_button) + navigation_layout.addWidget(up_button) + navigation_layout.addWidget(down_button) + navigation_box.setLayout(navigation_layout) + navigation_box.setMaximumWidth(250) + + layout = QHBoxLayout() + layout.addWidget(navigation_box) + + self.setLayout(layout) + + def move(self, direction: str) -> None: + """Move in the given direction on the tree view. Will select the next + node in that direction, based on the orientation of the widget. + + Args: + direction (str): The direction to move. Options: "up", "down", + "left", "right" + """ + if len(self.selected_nodes) == 0: + return + node_id = self.selected_nodes[0] + + if direction == "left": + if self.view_direction == "horizontal": + next_node = self.get_predecessor(node_id) + else: + next_node = self.get_next_track_node( + self.track_df, node_id, forward=False + ) + elif direction == "right": + if self.view_direction == "horizontal": + next_node = self.get_successor(node_id) + else: + next_node = self.get_next_track_node(self.track_df, node_id) + elif direction == "up": + if self.view_direction == "horizontal": + next_node = self.get_next_track_node(self.lineage_df, node_id) + if next_node is None: + next_node = self.get_next_track_node( + self.track_df, node_id + ) + else: + next_node = self.get_predecessor(node_id) + elif direction == "down": + if self.view_direction == "horizontal": + # try navigation within the current lineage_df first + next_node = self.get_next_track_node( + self.lineage_df, node_id, forward=False + ) + # if not found, look in the whole dataframe + # to enable jumping to the next node outside the current tree view content + if next_node is None: + next_node = self.get_next_track_node( + self.track_df, node_id, forward=False + ) + else: + next_node = self.get_successor(node_id) + else: + raise ValueError( + f"Direction must be one of 'left', 'right', 'up', 'down', got {direction}" + ) + if next_node is not None: + self.selected_nodes.add(next_node) + + def get_next_track_node( + self, df: pd.DataFrame, node_id: str, forward=True + ) -> str | None: + """Get the node at the same time point in an adjacent track. + + Args: + df (pd.DataFrame): the dataframe to be used. It can either be the + full track_df, or the subset lineage_df + node_id (str): The current node id to get the next from + forward (bool, optional): If true, pick the next track (right/down). + Otherwise, pick the previous track (left/up). Defaults to True. + """ + x_axis_pos = df.loc[df["node_id"] == node_id, "x_axis_pos"].values[0] + t = df.loc[df["node_id"] == node_id, "t"].values[0] + if forward: + neighbors = df.loc[ + (df["x_axis_pos"] > x_axis_pos) & (df["t"] == t) + ] + else: + neighbors = df.loc[ + (df["x_axis_pos"] < x_axis_pos) & (df["t"] == t) + ] + if not neighbors.empty: + # Find the closest index label + closest_index_label = ( + (neighbors["x_axis_pos"] - x_axis_pos).abs().idxmin() + ) + neighbor = neighbors.loc[closest_index_label, "node_id"] + return neighbor + + def get_predecessor(self, node_id: str) -> str | None: + """Get the predecessor node of the given node_id + + Args: + node_id (str): the node id to get the predecessor of + + Returns: + str | None: THe node id of the predecessor, or none if no predecessor + is found + """ + parent_id = self.track_df.loc[ + self.track_df["node_id"] == node_id, "parent_id" + ].values[0] + parent_row = self.track_df.loc[self.track_df["node_id"] == parent_id] + if not parent_row.empty: + return parent_row["node_id"].values[0] + + def get_successor(self, node_id: str) -> str | None: + """Get the successor node of the given node_id. If there are two children, + picks one arbitrarily. + + Args: + node_id (str): the node id to get the successor of + + Returns: + str | None: THe node id of the successor, or none if no successor + is found + """ + children = self.track_df.loc[self.track_df["parent_id"] == node_id] + if not children.empty: + child = children.to_dict("records")[0] + return child["node_id"] diff --git a/src/motile_plugin/widgets/tracks_view/tracks_viewer.py b/src/motile_plugin/widgets/tracks_view/tracks_viewer.py new file mode 100644 index 0000000..554201b --- /dev/null +++ b/src/motile_plugin/widgets/tracks_view/tracks_viewer.py @@ -0,0 +1,250 @@ +from dataclasses import dataclass + +import napari +from psygnal import Signal + +from motile_plugin.core import NodeType, Tracks +from motile_plugin.layers.track_graph import TrackGraph +from motile_plugin.layers.track_labels import TrackLabels +from motile_plugin.layers.track_points import TrackPoints +from motile_plugin.utils.node_selection import NodeSelectionList +from motile_plugin.utils.tree_widget_utils import ( + extract_lineage_tree, +) + + +@dataclass +class TracksLayerGroup: + tracks_layer: TrackGraph | None = None + seg_layer: TrackLabels | None = None + points_layer: TrackPoints | None = None + + +class TracksViewer: + """Purposes of the TracksViewer: + - Emit signals that all widgets should use to update selection or update + the currently displayed Tracks object + - Storing the currently displayed tracks + - Store shared rendering information like colormaps (or symbol maps) + - Interacting with the napari.Viewer by adding and removing layers + """ + + tracks_updated = Signal() + + @classmethod + def get_instance(cls, viewer=None): + if not hasattr(cls, "_instance"): + print("Making new tracking view controller") + if viewer is None: + raise ValueError("Make a viewer first please!") + cls._instance = TracksViewer(viewer) + return cls._instance + + def __init__( + self, + viewer: napari.viewer, + ): + self.viewer = viewer + # TODO: separate and document keybinds + self.viewer.bind_key("t")(self.toggle_display_mode) + + self.selected_nodes = NodeSelectionList() + self.selected_nodes.list_updated.connect(self.update_selection) + + self.tracking_layers = TracksLayerGroup() + self.tracks = None + + self.colormap = napari.utils.colormaps.label_colormap( + 49, + seed=0.5, + background_value=0, + ) + + self.symbolmap: dict[NodeType, str] = { + NodeType.END: "x", + NodeType.CONTINUE: "disc", + NodeType.SPLIT: "triangle_up", + } + self.mode = "all" + + def remove_napari_layer(self, layer: napari.layers.Layer | None) -> None: + """Remove a layer from the napari viewer, if present""" + if layer and layer in self.viewer.layers: + self.viewer.layers.remove(layer) + + def remove_napari_layers(self) -> None: + """Remove all tracking layers from the viewer""" + self.remove_napari_layer(self.tracking_layers.tracks_layer) + self.remove_napari_layer(self.tracking_layers.seg_layer) + self.remove_napari_layer(self.tracking_layers.points_layer) + + def add_napari_layers(self) -> None: + """Add new tracking layers to the viewer""" + if self.tracking_layers.tracks_layer is not None: + self.viewer.add_layer(self.tracking_layers.tracks_layer) + if self.tracking_layers.seg_layer is not None: + self.viewer.add_layer(self.tracking_layers.seg_layer) + if self.tracking_layers.points_layer is not None: + self.viewer.add_layer(self.tracking_layers.points_layer) + + def update_tracks(self, tracks: Tracks, name: str) -> None: + """Stop viewing a previous set of tracks and replace it with a new one. + Will create new segmentation and tracks layers and add them to the viewer. + + Args: + tracks (motile_plugin.core.Tracks): The tracks to visualize in napari. + name (str): The name of the tracks to display in the layer names + """ + self.selected_nodes._list = [] + self.tracks = tracks + # Remove old layers if necessary + self.remove_napari_layers() + + # deactivate the input labels layer + for layer in self.viewer.layers: + if isinstance(layer, napari.layers.Labels): + layer.visible = False + + # Create new layers + if tracks is not None and tracks.segmentation is not None: + self.tracking_layers.seg_layer = TrackLabels( + viewer=self.viewer, + data=tracks.segmentation[:, 0], + name=name + "_seg", + colormap=self.colormap, + tracks=self.tracks, + opacity=0.9, + selected_nodes=self.selected_nodes, + scale=self.tracks.scale, + ) + + else: + self.tracking_layers.seg_layer = None + + if ( + tracks is None + or tracks.graph is None + or tracks.graph.number_of_nodes() == 0 + ): + self.tracking_layers.tracks_layer = None + self.tracking_layers.points_layer = None + else: + self.tracking_layers.tracks_layer = TrackGraph( + viewer=self.viewer, + tracks=tracks, + name=name + "_tracks", + colormap=self.colormap, + scale=self.tracks.scale, + ) + self.tracking_layers.points_layer = TrackPoints( + viewer=self.viewer, + tracks=tracks, + name=name + "_points", + selected_nodes=self.selected_nodes, + symbolmap=self.symbolmap, + colormap=self.colormap, + scale=self.tracks.scale, + ) + + self.tracks_updated.emit() + self.add_napari_layers() + self.set_display_mode("all") + + def toggle_display_mode(self, event=None) -> None: + """Toggle the display mode between available options""" + + if self.mode == "lineage": + self.set_display_mode("all") + else: + self.set_display_mode("lineage") + + def set_display_mode(self, mode: str) -> None: + """Update the display mode and call to update colormaps for points, labels, and tracks""" + + # toggle between 'all' and 'lineage' + if mode == "lineage": + self.mode = "lineage" + self.viewer.text_overlay.text = "Toggle Display [T]\n Lineage" + else: + self.mode = "all" + self.viewer.text_overlay.text = "Toggle Display [T]\n All" + + self.viewer.text_overlay.visible = True + visible = self.filter_visible_nodes() + if self.tracking_layers.seg_layer is not None: + self.tracking_layers.seg_layer.update_label_colormap(visible) + if self.tracking_layers.points_layer is not None: + self.tracking_layers.points_layer.update_point_outline(visible) + if self.tracking_layers.tracks_layer is not None: + self.tracking_layers.tracks_layer.update_track_visibility(visible) + + def filter_visible_nodes(self) -> list[int]: + """Construct a list of track_ids that should be displayed""" + if self.mode == "lineage": + visible = [] + for node in self.selected_nodes: + visible += extract_lineage_tree(self.tracks.graph, node) + if self.tracks is None or self.tracks.graph is None: + return [] + return list( + { + self.tracks.graph.nodes[node]["tracklet_id"] + for node in visible + } + ) + else: + return "all" + + def update_selection(self) -> None: + """Sets the view and triggers visualization updates in other components""" + + self.set_napari_view() + visible = self.filter_visible_nodes() + if self.tracking_layers.seg_layer is not None: + self.tracking_layers.seg_layer.update_label_colormap(visible) + self.tracking_layers.points_layer.update_point_outline(visible) + self.tracking_layers.tracks_layer.update_track_visibility(visible) + + def set_napari_view(self) -> None: + """Adjust the current_step of the viewer to jump to the last item of the selected_nodes list""" + if len(self.selected_nodes) > 0: + node = self.selected_nodes[-1] + location = self.tracks.get_location(node, incl_time=True) + assert ( + len(location) == self.viewer.dims.ndim + ), f"Location {location} does not match viewer number of dims {self.viewer.dims.ndim}" + + step = list(self.viewer.dims.current_step) + for dim in self.viewer.dims.not_displayed: + step[dim] = int(location[dim] + 0.5) + + self.viewer.dims.current_step = step + + # check whether the new coordinates are inside or outside the field of view, then adjust the camera if needed + example_layer = self.tracking_layers.points_layer + corner_pixels = example_layer.corner_pixels + + # check which dimensions are shown, the first dimension is displayed on the x axis, and the second on the y_axis + dims_displayed = self.viewer.dims.displayed + x_dim = dims_displayed[-1] + y_dim = dims_displayed[-2] + + # find corner pixels for the displayed axes + _min_x = corner_pixels[0][x_dim] + _max_x = corner_pixels[1][x_dim] + _min_y = corner_pixels[0][y_dim] + _max_y = corner_pixels[1][y_dim] + + # check whether the node location falls within the corner pixel range + if not ( + (location[x_dim] > _min_x and location[x_dim] < _max_x) + and (location[y_dim] > _min_y and location[y_dim] < _max_y) + ): + camera_center = self.viewer.camera.center + + # set the center y and x to the center of the node, by using the index of the currently displayed dimensions + self.viewer.camera.center = ( + camera_center[0], + location[y_dim] * self.tracks.scale[y_dim], + location[x_dim] * self.tracks.scale[x_dim], + ) diff --git a/src/motile_plugin/widgets/tracks_view/tree_view_mode_widget.py b/src/motile_plugin/widgets/tracks_view/tree_view_mode_widget.py new file mode 100644 index 0000000..b410eaa --- /dev/null +++ b/src/motile_plugin/widgets/tracks_view/tree_view_mode_widget.py @@ -0,0 +1,58 @@ +from psygnal import Signal +from qtpy.QtWidgets import ( + QButtonGroup, + QGroupBox, + QHBoxLayout, + QRadioButton, + QVBoxLayout, + QWidget, +) + + +class TreeViewModeWidget(QWidget): + """Widget to switch between viewing all nodes versus nodes of one or more lineages in the tree widget""" + + change_mode = Signal(str) + + def __init__(self): + super().__init__() + + self.mode = "all" + + display_box = QGroupBox("Display [L]") + display_layout = QHBoxLayout() + button_group = QButtonGroup() + self.show_all_radio = QRadioButton("All cells") + self.show_all_radio.setChecked(True) + self.show_all_radio.clicked.connect(lambda: self._set_mode("all")) + self.show_lineage_radio = QRadioButton("Current lineage(s)") + self.show_lineage_radio.clicked.connect( + lambda: self._set_mode("lineage") + ) + button_group.addButton(self.show_all_radio) + button_group.addButton(self.show_lineage_radio) + display_layout.addWidget(self.show_all_radio) + display_layout.addWidget(self.show_lineage_radio) + display_box.setLayout(display_layout) + display_box.setMaximumWidth(250) + + layout = QVBoxLayout() + layout.addWidget(display_box) + + self.setLayout(layout) + + def _toggle_display_mode(self, event=None) -> None: + """Toggle display mode""" + + if self.mode == "lineage": + self._set_mode("all") + self.show_all_radio.setChecked(True) + else: + self._set_mode("lineage") + self.show_lineage_radio.setChecked(True) + + def _set_mode(self, mode: str): + """Emit signal to change the display mode""" + + self.mode = mode + self.change_mode.emit(mode) diff --git a/src/motile_plugin/widgets/tracks_view/tree_widget.py b/src/motile_plugin/widgets/tracks_view/tree_widget.py new file mode 100644 index 0000000..94cf330 --- /dev/null +++ b/src/motile_plugin/widgets/tracks_view/tree_widget.py @@ -0,0 +1,443 @@ +from typing import Any + +import napari +import numpy as np +import pandas as pd +import pyqtgraph as pg +from psygnal import Signal +from pyqtgraph.Qt import QtCore +from qtpy.QtCore import Qt +from qtpy.QtGui import QColor, QKeyEvent, QMouseEvent +from qtpy.QtWidgets import ( + QHBoxLayout, + QVBoxLayout, + QWidget, +) + +from motile_plugin.utils.tree_widget_utils import ( + extract_lineage_tree, + extract_sorted_tracks, +) +from motile_plugin.widgets.tracks_view.tracks_viewer import ( + TracksViewer, +) + +from .navigation_widget import NavigationWidget +from .tree_view_mode_widget import TreeViewModeWidget + + +class CustomViewBox(pg.ViewBox): + def __init__(self, *args, **kwds): + kwds["enableMenu"] = False + pg.ViewBox.__init__(self, *args, **kwds) + # self.setMouseMode(self.RectMode) + + ## reimplement right-click to zoom out + def mouseClickEvent(self, ev): + if ev.button() == QtCore.Qt.MouseButton.RightButton: + self.autoRange() + + ## reimplement mouseDragEvent to disable continuous axis zoom + def mouseDragEvent(self, ev, axis=None): + if ev.modifiers() == Qt.ShiftModifier: + # If Shift is pressed, enable rectangular zoom mode + self.setMouseMode(self.RectMode) + else: + # Otherwise, disable rectangular zoom mode + self.setMouseMode(self.PanMode) + + if ( + axis is not None + and ev.button() == QtCore.Qt.MouseButton.RightButton + ): + ev.ignore() + else: + pg.ViewBox.mouseDragEvent(self, ev, axis=axis) + + +class TreePlot(pg.PlotWidget): + node_clicked = Signal(Any, bool) # node_id, append + + def __init__(self) -> pg.PlotWidget: + """Construct the pyqtgraph treewidget. This is the actual canvas + on which the tree view is drawn. + """ + super().__init__(viewBox=CustomViewBox()) + self.setFocusPolicy(Qt.StrongFocus) + self.setTitle("Lineage Tree") + + self._pos = [] + self.adj = [] + self.symbolBrush = [] + self.symbols = [] + self.pen = [] + self.outline_pen = [] + self.node_ids = [] + self.sizes = [] + + self.view_direction = None + self.g = pg.GraphItem() + self.g.scatter.sigClicked.connect(self._on_click) + self.addItem(self.g) + self.set_view("vertical") + + def update( + self, + track_df: pd.DataFrame, + view_direction: str, + selected_nodes: list[Any], + ): + """Update the entire view, including the data, view direction, and + selected nodes + + Args: + track_df (pd.DataFrame): The dataframe containing the graph data + view_direction (str): The view direction + selected_nodes (list[Any]): The currently selected nodes to be highlighted + """ + self.set_data(track_df) + self.set_view(view_direction) + self._update_viewed_data() # this can be expensive + self.set_selection(selected_nodes) + + def set_view(self, view_direction: str): + """Set the view direction, saving the new value as an attribute and + changing the axes labels. Shortcuts if the view direction is already + correct. Does not actually update the rendered graph (need to call + _update_viewed_data). + + Args: + view_direction (str): "horizontal" or "vertical" + """ + if view_direction == self.view_direction: + return + self.view_direction = view_direction + if view_direction == "vertical": + self.setLabel("left", text="Time Point") + self.getAxis("left").setStyle(showValues=True) + self.getAxis("bottom").setStyle(showValues=False) + self.invertY(True) # to show tracks from top to bottom + elif view_direction == "horizontal": + self.setLabel("bottom", text="Time Point") + self.getAxis("bottom").setStyle(showValues=True) + self.setLabel("left", text="") + self.getAxis("left").setStyle(showValues=False) + self.invertY(False) + + def _on_click(self, _, points: np.ndarray, ev: QMouseEvent) -> None: + """Adds the selected point to the selected_nodes list. Called when + the user clicks on the TreeWidget to select nodes. + + Args: + points (np.ndarray): _description_ + ev (QMouseEvent): _description_ + """ + + modifiers = ev.modifiers() + node_id = points[0].data() + append = Qt.ShiftModifier == modifiers + self.node_clicked.emit(node_id, append) + + def set_data(self, track_df: pd.DataFrame) -> None: + """Updates the stored pyqtgraph content based on the given dataframe. + Does not render the new information (need to call _update_viewed_data). + + Args: + track_df (pd.DataFrame): The tracks df to compute the pyqtgraph + content for. Can be all lineages or any subset of them. + """ + self.track_df = track_df + self._create_pyqtgraph_content(track_df) + + def _update_viewed_data(self): + self.g.scatter.setPen( + pg.mkPen(QColor(150, 150, 150)) + ) # first reset the pen to avoid problems with length mismatch between the different properties + self.g.scatter.setSize(10) + if len(self._pos) == 0 or self.view_direction == "vertical": + pos_data = self._pos + else: + pos_data = np.flip(self._pos, axis=1) + + self.g.setData( + pos=pos_data, + adj=self.adj, + symbol=self.symbols, + symbolBrush=self.symbolBrush, + pen=self.pen, + data=self.node_ids, + ) + self.g.scatter.setPen(self.outline_pen) + self.g.scatter.setSize(self.sizes) + self.autoRange() + + def _create_pyqtgraph_content(self, track_df: pd.DataFrame) -> None: + """Parse the given track_df into the format that pyqtgraph expects + and save the information as attributes. + + Args: + track_df (pd.DataFrame): The dataframe containing the graph to be + rendered in the tree view. Can be all lineages or a subset. + """ + self._pos = [] + self.adj = [] + self.symbols = [] + self.symbolBrush = [] + self.pen = [] + self.sizes = [] + self.node_ids = [] + + if track_df is not None and not track_df.empty: + self.symbols = track_df["symbol"].to_list() + self.symbolBrush = track_df["color"].to_numpy() + self._pos = track_df[["x_axis_pos", "t"]].to_numpy() + self.node_ids = track_df["node_id"].to_list() + self.sizes = np.array( + [ + 8, + ] + * len(self.symbols) + ) + + valid_edges_df = track_df[track_df["parent_id"] != 0] + node_ids_to_index = { + node_id: index for index, node_id in enumerate(self.node_ids) + } + edges_df = valid_edges_df[["node_id", "parent_id"]] + self.pen = valid_edges_df["color"].to_numpy() + edges_df_mapped = edges_df.map(lambda _id: node_ids_to_index[_id]) + self.adj = edges_df_mapped.to_numpy() + + self.outline_pen = np.array( + [pg.mkPen(QColor(150, 150, 150)) for i in range(len(self._pos))] + ) + + def set_selection(self, selected_nodes: list[Any]) -> None: + """Set the provided list of nodes to be selected. Increases the size + and highlights the outline with blue. Also centers the view + if the first selected node is not visible in the current canvas. + + Args: + selected_nodes (list[Any]): A list of node ids to be selected. + """ + + # reset to default size and color to avoid problems with the array lengths + self.g.scatter.setPen(pg.mkPen(QColor(150, 150, 150))) + self.g.scatter.setSize(10) + + size = ( + self.sizes.copy() + ) # just copy the size here to keep the original self.sizes intact + + outlines = self.outline_pen.copy() + for i, node_id in enumerate(selected_nodes): + node_df = self.track_df.loc[self.track_df["node_id"] == node_id] + if not node_df.empty: + x_axis_pos = node_df["x_axis_pos"].values[0] + t = node_df["t"].values[0] + + # Update size and outline + index = self.node_ids.index(node_id) + size[index] += 5 + outlines[index] = pg.mkPen(color="c", width=2) + + # Center view based on the first selected node + if i == 0: + self._center_view(x_axis_pos, t) + + self.g.scatter.setPen(outlines) + self.g.scatter.setSize(size) + + def _center_view(self, center_x: int, center_y: int): + """Center the Viewbox on given coordinates""" + + if self.view_direction == "horizontal": + center_x, center_y = ( + center_y, + center_x, + ) # flip because the axes have changed in horizontal mode + + view_box = self.plotItem.getViewBox() + current_range = view_box.viewRange() + + x_range = current_range[0] + y_range = current_range[1] + + # Check if the new center is within the current range + if ( + x_range[0] <= center_x <= x_range[1] + and y_range[0] <= center_y <= y_range[1] + ): + return + + # Calculate the width and height of the current view + current_width = x_range[1] - x_range[0] + current_height = y_range[1] - y_range[0] + + # Calculate new ranges maintaining the current width and height + new_x_range = ( + center_x - current_width / 2, + center_x + current_width / 2, + ) + new_y_range = ( + center_y - current_height / 2, + center_y + current_height / 2, + ) + + view_box.setRange(xRange=new_x_range, yRange=new_y_range, padding=0) + + +class TreeWidget(QWidget): + """pyqtgraph-based widget for lineage tree visualization and navigation""" + + def __init__(self, viewer: napari.Viewer): + super().__init__() + self.track_df = pd.DataFrame() # all tracks + self.lineage_df = ( + pd.DataFrame() + ) # the currently viewed subset of lineages + self.graph = None + self.mode = "all" # options: "all", "lineage" + self.view_direction = "vertical" # options: "horizontal", "vertical" + + self.tracks_viewer = TracksViewer.get_instance(viewer) + self.selected_nodes = self.tracks_viewer.selected_nodes + self.selected_nodes.list_updated.connect(self._update_selected) + self.tracks_viewer.tracks_updated.connect(self._update_track_data) + + # Construct the tree view pyqtgraph widget + layout = QVBoxLayout() + + self.tree_widget: TreePlot = TreePlot() + self.tree_widget.node_clicked.connect(self.selected_nodes.add) + + # Add radiobuttons for switching between different display modes + self.mode_widget = TreeViewModeWidget() + self.mode_widget.change_mode.connect(self._set_mode) + + # Add navigation widget + self.navigation_widget = NavigationWidget( + self.track_df, + self.lineage_df, + self.view_direction, + self.selected_nodes, + ) + + # Construct a toolbar and set main layout + panel_layout = QHBoxLayout() + panel_layout.addWidget(self.mode_widget) + panel_layout.addWidget(self.navigation_widget) + + panel = QWidget() + panel.setLayout(panel_layout) + panel.setMaximumWidth(520) + + layout.addWidget(panel) + layout.addWidget(self.tree_widget) + + self.setLayout(layout) + self._update_track_data() + + def keyPressEvent(self, event: QKeyEvent) -> None: + """Catch arrow key presses to navigate in the tree + + Args: + event (QKeyEvent): The Qt Key event + """ + direction_map = { + Qt.Key_Left: "left", + Qt.Key_Right: "right", + Qt.Key_Up: "up", + Qt.Key_Down: "down", + } + + if event.key() == Qt.Key_L: + self.mode_widget._toggle_display_mode() + elif event.key() == Qt.Key_X: # only allow mouse zoom scrolling in X + self.tree_widget.setMouseEnabled(x=True, y=False) + elif event.key() == Qt.Key_Y: # only allow mouse zoom scrolling in Y + self.tree_widget.setMouseEnabled(x=False, y=True) + else: + if event.key() not in direction_map: + return + self.navigation_widget.move(direction_map[event.key()]) + + def keyReleaseEvent(self, ev): + """Reset the mouse scrolling when releasing the X/Y key""" + + if ev.key() == Qt.Key_X or ev.key() == Qt.Key_Y: + self.tree_widget.setMouseEnabled(x=True, y=True) + + def _update_selected(self): + """Called whenever the selection list is updated. Only re-computes + the full graph information when the new selection is not in the + lineage df (and in lineage mode) + """ + if self.mode == "lineage" and any( + node not in np.unique(self.lineage_df["node_id"].values) + for node in self.selected_nodes + ): + self._update_lineage_df() + self.tree_widget.update( + self.lineage_df, self.view_direction, self.selected_nodes + ) + else: + self.tree_widget.set_selection(self.selected_nodes) + + def _update_track_data(self) -> None: + """Called when the TracksViewer emits the tracks_updated signal, indicating + that a new set of tracks should be viewed. + """ + if self.tracks_viewer.tracks is None: + self.track_df = pd.DataFrame() + self.graph = None + else: + self.track_df = extract_sorted_tracks( + self.tracks_viewer.tracks, self.tracks_viewer.colormap + ) + self.graph = self.tracks_viewer.tracks.graph + + self.lineage_df = pd.DataFrame() + # also update the navigation widget + self.navigation_widget.track_df = self.track_df + self.navigation_widget.lineage_df = self.lineage_df + + # set mode back to all and view to vertical + self._set_mode("all") + self.tree_widget.update( + self.track_df, self.view_direction, self.selected_nodes + ) + + def _set_mode(self, mode: str) -> None: + """Set the display mode to all or lineage view. Currently, linage + view is always horizontal and all view is always vertical. + + Args: + mode (str): The mode to set the view to. Options are "all" or "lineage" + """ + if mode not in ["all", "lineage"]: + raise ValueError(f"Mode must be 'all' or 'lineage', got {mode}") + + self.mode = mode + if mode == "all": + self.view_direction = "vertical" + df = self.track_df + elif mode == "lineage": + self.view_direction = "horizontal" + self._update_lineage_df() + df = self.lineage_df + self.navigation_widget.view_direction = self.view_direction + self.tree_widget.update(df, self.view_direction, self.selected_nodes) + + def _update_lineage_df(self) -> None: + """Subset dataframe to include only nodes belonging to the current lineage""" + visible = [] + for node_id in self.selected_nodes: + visible += extract_lineage_tree(self.graph, node_id) + self.lineage_df = self.track_df[ + self.track_df["node_id"].isin(visible) + ].reset_index() + self.lineage_df["x_axis_pos"] = ( + self.lineage_df["x_axis_pos"].rank(method="dense").astype(int) - 1 + ) + self.navigation_widget.lineage_df = self.lineage_df diff --git a/tests/backend/test_motile_run.py b/tests/backend/test_motile_run.py new file mode 100644 index 0000000..1eb670e --- /dev/null +++ b/tests/backend/test_motile_run.py @@ -0,0 +1,27 @@ +import networkx as nx +import numpy as np +from motile_plugin.backend import MotileRun, SolverParams +from motile_plugin.core import Tracks + + +def test_save_load(tmp_path, graph_2d): + segmentation = np.zeros((10, 10, 10)) + for i in range(10): + segmentation[i][0:5, 0:5] = i + + run_name = "test" + run = MotileRun( + run_name=run_name, + solver_params=SolverParams(), + tracks=Tracks(graph=graph_2d, segmentation=segmentation), + ) + path = run.save(tmp_path) + newrun = MotileRun.load(path) + assert nx.utils.graphs_equal(run.tracks.graph, newrun.tracks.graph) + np.testing.assert_array_equal( + run.tracks.segmentation, newrun.tracks.segmentation + ) + assert run.run_name == newrun.run_name + assert run.time.replace(microsecond=0) == newrun.time + assert run.gaps == newrun.gaps + assert run.solver_params == newrun.solver_params diff --git a/tests/test_solver.py b/tests/backend/test_solver.py similarity index 89% rename from tests/test_solver.py rename to tests/backend/test_solver.py index d8c5c64..1415880 100644 --- a/tests/test_solver.py +++ b/tests/backend/test_solver.py @@ -1,5 +1,5 @@ +from motile_plugin.backend import SolverParams from motile_plugin.backend.solve import solve -from motile_plugin.backend.solver_params import SolverParams # capsys is a pytest fixture that captures stdout and stderr output streams diff --git a/tests/conftest.py b/tests/conftest.py index 48170dc..04bc675 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -71,7 +71,7 @@ def graph_2d(): ( "0_1", { - NodeAttr.POS.value: (50, 50), + NodeAttr.POS.value: [50, 50], NodeAttr.TIME.value: 0, NodeAttr.SEG_ID.value: 1, }, @@ -79,7 +79,7 @@ def graph_2d(): ( "1_1", { - NodeAttr.POS.value: (20, 80), + NodeAttr.POS.value: [20, 80], NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 1, }, @@ -87,7 +87,7 @@ def graph_2d(): ( "1_2", { - NodeAttr.POS.value: (60, 45), + NodeAttr.POS.value: [60, 45], NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 2, }, @@ -117,7 +117,7 @@ def multi_hypothesis_graph_2d(): ( "0_0_1", { - NodeAttr.POS.value: (50, 50), + NodeAttr.POS.value: [50, 50], NodeAttr.TIME.value: 0, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, @@ -126,7 +126,7 @@ def multi_hypothesis_graph_2d(): ( "0_1_1", { - NodeAttr.POS.value: (45, 45), + NodeAttr.POS.value: [45, 45], NodeAttr.TIME.value: 0, NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, @@ -135,7 +135,7 @@ def multi_hypothesis_graph_2d(): ( "1_0_1", { - NodeAttr.POS.value: (20, 80), + NodeAttr.POS.value: [20, 80], NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, @@ -144,7 +144,7 @@ def multi_hypothesis_graph_2d(): ( "1_1_1", { - NodeAttr.POS.value: (15, 75), + NodeAttr.POS.value: [15, 75], NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, @@ -153,7 +153,7 @@ def multi_hypothesis_graph_2d(): ( "1_0_2", { - NodeAttr.POS.value: (60, 45), + NodeAttr.POS.value: [60, 45], NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 2, @@ -162,7 +162,7 @@ def multi_hypothesis_graph_2d(): ( "1_1_2", { - NodeAttr.POS.value: (55, 40), + NodeAttr.POS.value: [55, 40], NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 2, @@ -286,7 +286,7 @@ def graph_3d(): ( "0_1", { - NodeAttr.POS.value: (50, 50, 50), + NodeAttr.POS.value: [50, 50, 50], NodeAttr.TIME.value: 0, NodeAttr.SEG_ID.value: 1, }, @@ -294,7 +294,7 @@ def graph_3d(): ( "1_1", { - NodeAttr.POS.value: (20, 50, 80), + NodeAttr.POS.value: [20, 50, 80], NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 1, }, @@ -302,7 +302,7 @@ def graph_3d(): ( "1_2", { - NodeAttr.POS.value: (60, 50, 45), + NodeAttr.POS.value: [60, 50, 45], NodeAttr.TIME.value: 1, NodeAttr.SEG_ID.value: 2, }, @@ -324,7 +324,7 @@ def multi_hypothesis_graph_3d(): ( "0_0_1", { - NodeAttr.POS.value: (50, 50, 50), + NodeAttr.POS.value: [50, 50, 50], NodeAttr.TIME.value: 0, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, @@ -333,7 +333,7 @@ def multi_hypothesis_graph_3d(): ( "0_1_1", { - NodeAttr.POS.value: (45, 50, 55), + NodeAttr.POS.value: [45, 50, 55], NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, @@ -342,7 +342,7 @@ def multi_hypothesis_graph_3d(): ( "1_0_1", { - NodeAttr.POS.value: (20, 50, 80), + NodeAttr.POS.value: [20, 50, 80], NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 1, @@ -351,7 +351,7 @@ def multi_hypothesis_graph_3d(): ( "1_0_2", { - NodeAttr.POS.value: (60, 50, 45), + NodeAttr.POS.value: [60, 50, 45], NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 0, NodeAttr.SEG_ID.value: 2, @@ -360,7 +360,7 @@ def multi_hypothesis_graph_3d(): ( "1_1_1", { - NodeAttr.POS.value: (15, 50, 70), + NodeAttr.POS.value: [15, 50, 70], NodeAttr.TIME.value: 1, NodeAttr.SEG_HYPO.value: 1, NodeAttr.SEG_ID.value: 1, diff --git a/tests/core/test_tracks.py b/tests/core/test_tracks.py new file mode 100644 index 0000000..6428b5a --- /dev/null +++ b/tests/core/test_tracks.py @@ -0,0 +1,33 @@ +import pytest +from motile_plugin.core import Tracks +from motile_toolbox.candidate_graph import NodeAttr + + +def test_tracks(graph_3d): + tracks = Tracks(graph=graph_3d) + assert tracks.get_location("0_1") == [50, 50, 50] + assert tracks.get_time("0_1") == 0 + assert tracks.get_location("0_1", incl_time=True) == [0, 50, 50, 50] + with pytest.raises(KeyError): + tracks.get_location("0") + + tracks_wrong_attr = Tracks( + graph=graph_3d, time_attr="test", pos_attr="test" + ) + with pytest.raises(KeyError): + tracks_wrong_attr.get_location("0_1") + with pytest.raises(KeyError): + tracks_wrong_attr.get_time("0_1") + + # test multiple position attrs + pos_attr = ("z", "y", "x") + for node in graph_3d.nodes(): + pos = graph_3d.nodes[node][NodeAttr.POS.value] + z, y, x = pos + del graph_3d.nodes[node][NodeAttr.POS.value] + graph_3d.nodes[node]["z"] = z + graph_3d.nodes[node]["y"] = y + graph_3d.nodes[node]["x"] = x + + tracks = Tracks(graph=graph_3d, pos_attr=pos_attr) + assert tracks.get_location("0_1") == [50, 50, 50]