| |
|
|
| import logging |
| from collections import deque |
| from pathlib import Path |
|
|
| import networkx as nx |
| import numpy as np |
| import pandas as pd |
| import tifffile |
| from skimage.measure import regionprops |
| from tqdm import tqdm |
| from typing import List, Optional, Tuple |
|
|
# Module-level logger. Its level is pinned to INFO here, so the debug messages
# emitted by this module are discarded unless a caller lowers the level again.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
|
|
|
|
| |
| |
|
|
class CtcTracklet:
    """A linear chain of graph nodes that becomes one track in CTC format.

    Args:
        parent: Node preceding the tracklet in the graph (``ctc_tracklets``
            uses -1 when there is none).
        nodes: Node ids of the tracklet, in order.
        start_frame: Frame of the first node.

    Tracklets are ordered by ``start_frame`` first and by ``parent`` second.
    """

    def __init__(self, parent: int, nodes: List[int], start_frame: int) -> None:
        self.start_frame = start_frame
        self.nodes = nodes
        self.parent = parent

    def __lt__(self, other):
        this_frame = self.start_frame
        other_frame = other.start_frame
        if this_frame < other_frame:
            return True
        if this_frame > other_frame:
            return False
        if this_frame == other_frame:
            # Same start frame: break the tie by parent node id.
            return self.parent < other.parent
        # Start frames are mutually incomparable (e.g. NaN).
        return None

    def __str__(self) -> str:
        return f"Tracklet(parent={self.parent}, nodes={self.nodes})"

    def __repr__(self) -> str:
        return str(self)
|
|
|
|
def ctc_tracklets(G: nx.DiGraph, frame_attribute: str = "time") -> List[CtcTracklet]:
    """Return all CTC tracklets in a graph.

    A tracklet is a maximal chain of nodes that starts at

    - a daughter of a division (the parent has out_degree 2),
    - an appearance (in_degree 0), or
    - the target of a gap closing edge (delta_t to the parent node > 1),

    whose inner nodes have in_degree 1, out_degree 1 and delta_t <= 1 to the
    next node, and which ends at a node

    - that divides (out_degree 2),
    - that disappears (out_degree 0), or
    - that is followed by a gap closing edge (delta_t to the next node > 1).

    Args:
        G: Directed track graph whose nodes carry `frame_attribute`.
        frame_attribute: Name of the node attribute holding the frame index.

    Returns:
        Tracklets in discovery order. ``parent`` is the node preceding the
        tracklet (the dividing node or the gap-closing source), or -1 for
        appearances.

    Raises:
        ValueError: if a node has more than two successors, or more than one
            predecessor (a merge would otherwise be walked once per incoming
            tracklet and duplicated in the output).
    """
    for n in G.nodes:
        if G.in_degree[n] > 1:
            raise ValueError(f"More than one parent for node {n}!")

    tracklets = []

    # Queue of (parent node, first node) pairs; -1 marks "no parent".
    starts = deque(
        (p, d) for p in G.nodes if G.out_degree[p] == 2 for d in G.successors(p)
    )
    starts.extend((-1, n) for n in G.nodes if G.in_degree[n] == 0)

    while starts:
        parent, start = starts.popleft()
        nodes = [start]

        c = start
        while True:
            out_deg = G.out_degree[c]
            if out_deg > 2:
                raise ValueError("More than two daughters!")
            if out_deg != 1:
                # Division (2) or disappearance (0) ends the tracklet.
                break
            t_c = G.nodes[c][frame_attribute]
            suc = next(iter(G.successors(c)))
            t_suc = G.nodes[suc][frame_attribute]
            if t_suc - t_c > 1:
                logger.debug(
                    f"Gap closing edge from `{c} (t={t_c})` to `{suc} (t={t_suc})`"
                )
                # The node after the gap starts a new tracklet with `c` as parent.
                starts.append((c, suc))
                break
            c = suc
            nodes.append(c)

        tracklets.append(
            CtcTracklet(
                parent=parent, nodes=nodes, start_frame=G.nodes[start][frame_attribute]
            )
        )

    return tracklets
|
|
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
|
|
| def _check_ctc_df(df: pd.DataFrame, masks: np.ndarray): |
| """Sanity check of all labels in a CTC dataframe are present in the masks.""" |
| |
| if len(df) == 0 and np.all(masks == 0): |
| return True |
|
|
| for t in range(df.t1.min(), df.t1.max()): |
| sub = df[(df.t1 <= t) & (df.t2 >= t)] |
| sub_lab = set(sub.label) |
| |
| masks_lab = set(np.where(np.bincount(masks[t].ravel()))[0]) - {0} |
| if not sub_lab.issubset(masks_lab): |
| print(f"Missing labels in masks at t={t}: {sub_lab - masks_lab}") |
| return False |
| return True |
|
|
|
|
def graph_to_edge_table(
    graph: nx.DiGraph,
    frame_attribute: str = "time",
    edge_attribute: str = "weight",
    outpath: Optional[Path] = None,
) -> pd.DataFrame:
    """Write the edges of a graph to a table.

    The table has columns `source_frame`, `source_label`, `target_frame`,
    `target_label` and `weight`, sorted ascending by the first four columns.
    Source and target are identified by their frame (0-indexed) and the label
    of the object in the input masks.

    Args:
        graph: With node attributes `frame_attribute` and 'label', and edge
            attribute `edge_attribute`.
        frame_attribute: Name of the frame attribute of the nodes in `graph`.
        edge_attribute: Name of the score attribute of the edges in `graph`.
        outpath: If given, save the table as CSV (comma separated, with a
            header line, without index) to this path, creating parent
            directories as needed.

    Returns:
        pd.DataFrame: Edges DataFrame with columns
            ['source_frame', 'source_label', 'target_frame', 'target_label', 'weight']
    """
    rows = []
    for u, v, edge_data in graph.edges(data=True):
        source = graph.nodes[u]
        target = graph.nodes[v]
        rows.append(
            [
                int(source[frame_attribute]),
                int(source["label"]),
                int(target[frame_attribute]),
                int(target["label"]),
                float(edge_data[edge_attribute]),
            ]
        )

    df = pd.DataFrame(
        rows,
        columns=[
            "source_frame",
            "source_label",
            "target_frame",
            "target_label",
            "weight",
        ],
    )
    df = df.sort_values(
        by=["source_frame", "source_label", "target_frame", "target_label"],
        ascending=True,
    )

    if outpath is not None:
        outpath = Path(outpath)
        outpath.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(outpath, index=False, header=True, sep=",")

    return df
|
|
|
|
def graph_to_ctc(
    graph: nx.DiGraph,
    masks_original: np.ndarray,
    check: bool = True,
    frame_attribute: str = "time",
    outdir: Optional[Path] = None,
) -> Tuple[pd.DataFrame, np.ndarray]:
    """Convert a graph to a CTC track dataframe and relabeled masks.

    The graph is split into tracklets with ``ctc_tracklets``. The sorted
    tracklets are numbered 1, 2, ..., and every node's region in
    ``masks_original`` is painted with its tracklet's number.

    Args:
        graph: with node attributes `frame_attribute` and "label"; each node's
            "label" must be a region label of ``masks_original`` at that frame.
        masks_original: sequence of label masks (one per frame) with unique labels.
        check: If True, run ``_check_ctc_df`` on the result. It prints missing
            labels; its return value is ignored here.
        frame_attribute: Name of the frame attribute in the graph nodes.
        outdir: If given, write ``res_track.txt`` and ``res_trackXXXX.tif``
            (zstd compressed) into this directory, creating it if needed.

    Returns:
        pd.DataFrame: track dataframe with integer columns
            ['label', 't1', 't2', 'parent'] (parent 0 means no parent)
        np.ndarray: stacked masks with a unique label for each track (0 = background)

    Raises:
        RuntimeError: if a node's mask is empty, or overlaps pixels already
            assigned to another tracklet.
        KeyError: if a node's label has no region in its frame of
            ``masks_original``.
        ValueError: propagated from ``ctc_tracklets`` for graphs it rejects.
    """
    tracklets = ctc_tracklets(graph, frame_attribute=frame_attribute)

    # Per frame: original region label -> bounding-box slice of that region.
    regions = tuple(
        {reg.label: reg.slice for reg in regionprops(m)} for m in masks_original
    )

    masks = np.stack([np.zeros_like(m) for m in masks_original])
    rows = []

    # Last node of every processed tracklet -> its CTC label. The "no parent"
    # sentinel -1 used by ctc_tracklets maps to CTC parent 0.
    node_to_tracklets = {-1: 0}

    # Tracklets are sorted by start frame (then parent), so, assuming frames
    # increase along edges, a parent tracklet is numbered before its children
    # look it up.
    for i, tracklet in tqdm(
        enumerate(sorted(tracklets)),
        total=len(tracklets),
        desc="Converting graph to CTC results",
    ):
        label = i + 1
        nodes = tracklet.nodes
        end = nodes[-1]

        t1 = tracklet.start_frame
        t2 = graph.nodes[end][frame_attribute]

        node_to_tracklets[end] = label

        for n in nodes:
            node = graph.nodes[n]
            t = node[frame_attribute]
            lab = node["label"]
            ss = regions[t][lab]
            m = masks_original[t][ss] == lab
            # Check emptiness first: max() over an empty selection would raise
            # a numpy ValueError before the intended RuntimeError.
            if np.count_nonzero(m) == 0:
                raise RuntimeError(f"Empty mask at t={t}, label={lab}")
            # Basic slicing yields a view, so writing to `target` updates `masks`.
            target = masks[t][ss]
            if target[m].max() > 0:
                raise RuntimeError(f"Overlapping masks at t={t}, label={lab}")
            target[m] = label

        rows.append([label, t1, t2, node_to_tracklets[tracklet.parent]])

    df = pd.DataFrame(rows, columns=["label", "t1", "t2", "parent"], dtype=int)

    if check:
        _check_ctc_df(df, masks)

    if outdir is not None:
        outdir = Path(outdir)
        outdir.mkdir(parents=True, exist_ok=True)
        df.to_csv(outdir / "res_track.txt", index=False, header=False, sep=" ")
        for i, m in tqdm(enumerate(masks), total=len(masks), desc="Saving masks"):
            tifffile.imwrite(
                outdir / f"res_track{i:04d}.tif",
                m,
                compression="zstd",
            )

    return df, masks
|
|
|
|
def ctc_to_graph(df: pd.DataFrame, frame_attribute: str = "time"):
    """From a CTC dataframe, create a digraph with frame_attribute and label as node attributes.

    Every track that is alive for at least one frame (``t1 <= t2``) becomes a
    single node whose id is its ``label``. Its node attributes are ``label``
    and ``frame_attribute``, the latter set to the track's start frame ``t1``.
    For each such track with ``parent != 0``, an edge ``parent -> label`` is
    added.

    NOTE(review): tracks are not expanded into one node per frame, so this is
    a track-level lineage graph. Confirm this is what callers expect.

    Args:
        df: pd.DataFrame with columns `label`, `t1`, `t2`, `parent` (man_track.txt)
        frame_attribute: Name of the frame attribute in the graph nodes.

    Returns:
        graph: The track graph (empty if `df` has no rows).
    """
    graph = nx.DiGraph()
    if len(df) == 0:
        return graph

    # A stable sort by start frame visits tracks in the order a frame-by-frame
    # sweep would first see them, without re-filtering the table per frame.
    alive = df[df.t1 <= df.t2].sort_values("t1", kind="mergesort")

    for row in tqdm(alive.itertuples(), total=len(alive)):
        label, start, parent = row.label, row.t1, row.parent

        if not graph.has_node(label):
            attrs = {"label": label, frame_attribute: start}
            graph.add_node(label, **attrs)

        if parent != 0:
            graph.add_edge(parent, label)

    return graph
|
|