| |
|
|
| import logging |
| from itertools import chain |
|
|
| import networkx as nx |
| import numpy as np |
| import scipy |
| from tqdm import tqdm |
|
|
| from .track_graph import TrackGraph |
| from typing import Optional, Tuple |
|
|
# Module-level logger named after this module; used for progress/warning messages.
logger = logging.getLogger(__name__)
|
|
|
|
def copy_edge(edge: tuple, source: nx.DiGraph, target: nx.DiGraph):
    """Copy a single edge from ``source`` into ``target``.

    Only ``edge[0]`` and ``edge[1]`` are read; any further elements of
    ``edge`` (e.g. a data dict) are ignored. Node and edge attributes are
    taken from ``source``. An endpoint that already exists in ``target`` is
    left untouched, and a missing endpoint is added with its ``source``
    attributes before the edge is added.
    """
    head, tail = edge[0], edge[1]
    for endpoint in (head, tail):
        if endpoint not in target.nodes:
            target.add_node(endpoint, **source.nodes[endpoint])
    target.add_edge(head, tail, **source.edges[head, tail])
|
|
|
|
def track_greedy(
    candidate_graph,
    allow_divisions=True,
    threshold=0.5,
    edge_attr="weight",
):
    """Greedy global matching of candidate edges into tracks.

    All edges of ``candidate_graph`` are visited in order of decreasing
    ``edge_attr``. An edge is kept when all of the following hold:

    - its weight is at least ``threshold``;
    - its target node has no incoming edge in the solution yet;
    - its source node has fewer outgoing edges than allowed (2 with
      divisions, otherwise 1).

    Args:
        candidate_graph: Directed graph of candidate links. Every edge must
            carry the attribute ``edge_attr``.
        allow_divisions (bool, optional):
            Whether to model divisions. Defaults to True.
        threshold (float, optional): Minimum weight for an edge to be kept.
            Defaults to 0.5.
        edge_attr (str, optional): Name of the edge attribute used as the
            weight. Defaults to "weight".

    Returns:
        solution_graph: NetworkX graph of tracks. Only nodes that are
        endpoints of a kept edge appear in it. Node and edge attributes are
        copied from ``candidate_graph``.

    Raises:
        AssertionError: If a visited edge has a weight above 1.0 (weights are
            assumed to be normalized to [0, 1]). This check only runs when
            assertions are enabled.
    """
    logger.info("Running greedy tracker")

    # Divisions allow a parent to link to two children; otherwise only one.
    max_children = 2 if allow_divisions else 1
    solution_graph = nx.DiGraph()

    ranked_edges = sorted(
        candidate_graph.edges(data=True),
        key=lambda candidate: candidate[2][edge_attr],
        reverse=True,
    )

    for edge in tqdm(ranked_edges, desc="Greedily matched edges"):
        parent, child, attrs = edge
        score = attrs[edge_attr]
        assert score <= 1.0, "Edge weights are assumed to be normalized to [0,1]"

        # Edges are visited best-first, so every remaining edge is below the threshold too.
        if score < threshold:
            break

        child_linked = (
            child in solution_graph.nodes and solution_graph.in_degree(child) > 0
        )
        if child_linked:
            continue

        parent_saturated = (
            parent in solution_graph
            and solution_graph.out_degree(parent) >= max_children
        )
        if parent_saturated:
            continue

        copy_edge(edge, candidate_graph, solution_graph)

    return solution_graph
|
|
|
|
|
|
def _link_frame_pair(G, src, dst, weights, max_distance, max_neighbors):
    """Add candidate edges from nodes ``src`` to nodes ``dst`` in ``G``.

    Each source node is linked to destination nodes in order of increasing
    Euclidean distance between their ``coords``. See ``build_graph`` for the
    meaning of ``weights``, ``max_distance`` and ``max_neighbors``.

    Returns:
        int: Number of edges added.

    Raises:
        ValueError: If an edge must be weighted by distance but
            ``max_distance`` is None.
    """
    # An empty frame has nothing to link; np.stack would also fail on [].
    if len(src) == 0 or len(dst) == 0:
        return 0

    src_xy = np.stack([np.asarray(G.nodes[n]["coords"]) for n in src])
    dst_xy = np.stack([np.asarray(G.nodes[n]["coords"]) for n in dst])
    dists = scipy.spatial.distance.cdist(src_xy, dst_xy)

    added = 0
    for i, node_i in enumerate(src):
        n_linked = 0
        for j in np.argsort(dists[i]):
            # A falsy max_neighbors (None or 0) means "no limit".
            if max_neighbors and n_linked >= max_neighbors:
                break
            dist = dists[i, j]
            if max_distance is not None and dist > max_distance:
                # Candidates are sorted by distance, so all remaining ones are farther.
                break
            # Index the original sequence so node ids keep their own type
            # (np.array(dst) would turn them into numpy scalars or arrays).
            node_j = dst[j]
            if weights is None:
                if max_distance is None:
                    raise ValueError(
                        "max_distance is required when edges are weighted by "
                        "distance (use_distance=True or weights=None)"
                    )
                G.add_edge(node_i, node_j, weight=1 - dist / max_distance)
            elif (node_i, node_j) in weights:
                G.add_edge(node_i, node_j, weight=weights[(node_i, node_j)])
            else:
                continue
            n_linked += 1
            added += 1
    return added


def build_graph(
    nodes: dict,
    weights: Optional[tuple] = None,
    use_distance: bool = False,
    max_distance: Optional[int] = None,
    max_neighbors: Optional[int] = None,
    delta_t=1,
):
    """Build a candidate graph linking detections in nearby frames.

    Every record in ``nodes`` becomes a graph node with attributes ``time``,
    ``label``, ``coords`` and ``weight=1``. For every frame pair
    ``(t, t + d)`` with ``1 <= d <= delta_t``, each node in frame ``t`` is
    linked to nodes in frame ``t + d``, nearest (by ``coords``) first.

    Args:
        nodes: Iterated as a sequence of node records, each with keys
            ``"id"``, ``"time"``, ``"label"`` and ``"coords"`` (despite the
            ``dict`` annotation).
        weights: Optional sequence of ``((node_i, node_j), weight)`` pairs.
            When given and ``use_distance`` is False, only listed pairs become
            edges, and each carries its listed weight.
        use_distance: If True, ignore ``weights`` and weight each edge as
            ``1 - dist / max_distance``.
        max_distance: Pairs farther apart than this are never linked. This is
            required when edges are weighted by distance.
        max_neighbors: Maximum number of edges added per source node and
            frame pair. ``None`` (or 0) means unlimited.
        delta_t: Largest frame gap bridged by an edge.

    Returns:
        nx.DiGraph: The candidate graph. Every edge has a ``weight`` attribute.

    Raises:
        ValueError: If edges are weighted by distance, ``max_distance`` is
            None, and some frame pair has detections in both frames.
    """
    logger.info(f"Build candidate graph with {delta_t=}")
    G = nx.DiGraph()

    for node in nodes:
        G.add_node(
            node["id"],
            time=node["time"],
            label=node["label"],
            coords=node["coords"],
            weight=1,
        )

    if use_distance:
        weights = None
    if weights:
        # ((node_i, node_j), w) pairs -> {(node_i, node_j): w} for O(1) lookup.
        weights = {w[0]: w[1] for w in weights}

    graph = TrackGraph(G, frame_attribute="time")
    # All (t, t + d) pairs, grouped by gap d. t_end is used as an exclusive bound.
    frame_pairs = [
        (t, t + d)
        for d in range(1, delta_t + 1)
        for t in range(graph.t_begin, graph.t_end - d)
    ]
    iterator = tqdm(frame_pairs, leave=False)

    # G starts without edges, so a running sum gives the total edge count
    # without an O(V) len(G.edges) call per frame pair.
    n_edges = 0
    for t_begin, t_end in iterator:
        e_added = _link_frame_pair(
            G,
            list(graph.nodes_by_frame(t_begin)),
            list(graph.nodes_by_frame(t_end)),
            weights,
            max_distance,
            max_neighbors,
        )
        n_edges += e_added
        if e_added == 0:
            logger.warning(f"No candidate edges in frame {t_begin}")
        iterator.set_description(
            f"{e_added} edges in frame {t_begin} Total edges: {n_edges}"
        )

    logger.info(f"Added {len(G.nodes)} vertices, {len(G.edges)} edges")

    return G
|
|