Tracking with Trackastra#

We track the segmented cells with Trackastra. Make sure that torch is installed with a CUDA version that matches your system, and that trackastra is installed.

To use the ILP solver, which gives the best results but is also the slowest, you must install additional libraries; see the installation instructions for more details.
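Before running the notebook, it can help to check which packages are visible in the current environment. The snippet below is a minimal sketch: the helper `missing_packages` is a hypothetical name, and listing `motile` as the optional ILP dependency is an assumption based on the `motile.solver` lines in the tracking log further down, not a statement of the official requirements.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the top-level package names that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


required = ["torch", "trackastra"]
optional_ilp = ["motile"]  # assumption: inferred from the tracking log, see the installation instructions

print("Missing required packages:", missing_packages(required))
print("Missing optional ILP packages:", missing_packages(optional_ilp))

# If torch is present, report the CUDA version it was built with
# and whether a GPU is actually usable.
if not missing_packages(["torch"]):
    import torch

    print("torch:", torch.__version__)
    print("built with CUDA:", torch.version.cuda)
    print("CUDA available:", torch.cuda.is_available())
```

If `CUDA available` is `False` on a machine with a GPU, the installed torch build most likely does not match the local CUDA driver.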

import torch
from trackastra.model import Trackastra
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
from pathlib import Path
from tifffile import imread
import skimage as ski
import numpy as np
INFO:numexpr.utils:Note: detected 72 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
INFO:numexpr.utils:Note: NumExpr detected 72 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.

Loading the data#

device = "cuda" if torch.cuda.is_available() else "cpu"

# load some test data images and masks
root_path = Path("../../pvc/scratch/interaction_cells/datasets/")
print("Loading image...")
imgs = imread(root_path / "series003_cCAR_tumor.tif")
print("Loaded image")
imgs = imgs[:, 0, ...]
Loading image...
Loaded image

Again, we downsample the images by a factor of two along each spatial axis to speed up processing.

imgs = ski.transform.resize(imgs, (imgs.shape[0], imgs.shape[1] // 2, imgs.shape[2] // 2), anti_aliasing=True)
masks = imread("../../pvc/scratch/SHARE/cancer_cell_masks/series003_cCAR_tumor.tif").astype(np.uint16)
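The tracker receives the images and the masks together, so both stacks need the same (T, H, W) shape. The masks loaded above are assumed to have been computed on images that were already downsampled. If that is not the case, label images should be downsampled without interpolation so that label values are never blended. Below is a minimal NumPy sketch; `downsample_labels_2x` and `check_same_shape` are hypothetical helpers, not part of Trackastra or scikit-image.

```python
import numpy as np


def downsample_labels_2x(masks):
    """Halve H and W of a (T, H, W) label stack by keeping every second pixel.

    The output size is floor(H / 2) x floor(W / 2), matching `shape // 2`.
    No interpolation is involved, so label values are preserved.
    """
    h2, w2 = masks.shape[1] // 2, masks.shape[2] // 2
    return masks[:, : 2 * h2 : 2, : 2 * w2 : 2]


def check_same_shape(imgs, masks):
    """Raise a ValueError if the image and mask stacks differ in shape."""
    if imgs.shape != masks.shape:
        raise ValueError(f"Shape mismatch: imgs {imgs.shape} vs masks {masks.shape}")


# Toy example: a full-resolution label stack and an already downsampled image stack
toy_masks = np.arange(20, dtype=np.uint16).reshape(1, 5, 4)
toy_imgs = np.zeros((1, 2, 2), dtype=np.float32)
check_same_shape(toy_imgs, downsample_labels_2x(toy_masks))
```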

Tracking#

Load the pre-trained model#

Below we load the pre-trained trackastra model.

# Load a pretrained model
model = Trackastra.from_pretrained("general_2d", device=device)
INFO:trackastra.model.model:Loading model state from /home/achard/.trackastra/.models/general_2d/model.pt
INFO:trackastra.model.model_api:Using device cuda
/home/achard/.trackastra/.models/general_2d already downloaded, skipping.

Computing the tracks#

Below we compute the tracks.

Important

The following parameters are especially important for tracking:

  • mode: The solver used to link detections. greedy simply takes the most probable option at each step; it is fast but simple. greedy_nodiv is the same but without divisions, which matches this data. Finally, ilp gives the best results but is also the slowest, and requires additional libraries to be installed.

  • use_distance: If set, detections that are farther apart than max_distance are not linked. This must be tuned depending on the data.

  • max_distance: The maximum distance between two detections for them to be linked; only relevant when use_distance is set.

  • allow_divisions: If set, the algorithm allows divisions of tracks. Disable this if divisions are not part of the data.
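To make the effect of these parameters concrete, below is a toy sketch of greedy linking without divisions between two frames, with a distance cutoff. It is only an illustration and not Trackastra's implementation: it uses the plain Euclidean distance as the cost, whereas Trackastra uses learned association scores, and `greedy_link` is a hypothetical helper.

```python
import math


def greedy_link(points_a, points_b, max_distance=30.0):
    """Link points of frame A to points of frame B, one-to-one, closest pairs first.

    Pairs farther apart than `max_distance` are never linked, and each point
    takes part in at most one link, so no divisions can occur.
    """
    candidates = []
    for i, pa in enumerate(points_a):
        for j, pb in enumerate(points_b):
            d = math.dist(pa, pb)
            if d <= max_distance:
                candidates.append((d, i, j))

    used_a, used_b, links = set(), set(), []
    for d, i, j in sorted(candidates):
        if i in used_a or j in used_b:
            continue  # one of the two points is already linked
        used_a.add(i)
        used_b.add(j)
        links.append((i, j))
    return sorted(links)


frame_a = [(0.0, 0.0), (100.0, 100.0)]
frame_b = [(3.0, 4.0), (200.0, 200.0)]
print(greedy_link(frame_a, frame_b, max_distance=30))  # → [(0, 0)]
```

The second point of each frame stays unlinked because the two are about 141 units apart, which is above the cutoff. A cutoff that is too small breaks the tracks of fast-moving cells, while one that is too large allows links between unrelated cells.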

# Track the cells
track_graph = model.track(
    imgs,
    masks,
    mode="ilp",  # or mode="greedy", or "greedy_nodiv"
    use_distance=True,
    max_distance=30,
    allow_divisions=False,
)
INFO:trackastra.model.model_api:Predicting weights for candidate graph
INFO:trackastra.data.wrfeat:Extracting features from 162 detections
INFO:trackastra.data.wrfeat:Using single process for feature extraction
Extracting features: 100%|██████████| 162/162 [00:12<00:00, 12.50it/s]
INFO:trackastra.model.model_api:Building windows
Building windows: 100%|██████████| 159/159 [00:00<00:00, 14996.84it/s]
INFO:trackastra.model.model_api:Predicting windows
Computing associations: 100%|██████████| 159/159 [00:05<00:00, 29.77it/s]
INFO:trackastra.model.model_api:Running greedy tracker
INFO:trackastra.tracking.tracking:Build candidate graph with delta_t=1
INFO:trackastra.tracking.tracking:Added 26729 vertices, 32401 edges                          
INFO:trackastra.tracking.ilp:Using `gt` ILP config.
INFO:motile.solver:Adding NodeSelection cost...
INFO:motile.solver:Adding NodeSelected variables...
INFO:motile.solver:Adding EdgeSelection cost...
INFO:motile.solver:Adding EdgeSelected variables...
INFO:motile.solver:Adding Appear cost...
INFO:motile.solver:Adding NodeAppear variables...
INFO:motile.solver:Adding Disappear cost...
INFO:motile.solver:Adding NodeDisappear variables...
INFO:motile.solver:Adding MaxParents constraint...
INFO:motile.solver:Adding MaxChildren constraint...
INFO:motile.solver:Computing costs...
INFO:motile.solver:ILP solver returned with: OPTIMAL
Candidate graph		26729 nodes	32401 edges
Solution graph		26729 nodes	25316 edges
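The two counts above give a quick sanity check. Assuming the solution graph consists only of simple chains (no divisions, and at most one parent and one child per node), the number of tracks equals the number of nodes minus the number of edges. The sketch below verifies this identity on a toy graph with a small union-find and applies it to the reported counts; `count_tracks` is a hypothetical helper.

```python
def count_tracks(nodes, edges):
    """Count connected components (tracks) with a small union-find."""
    parent = {n: n for n in nodes}

    def find(n):
        # Follow parent pointers to the root, compressing the path on the way
        while parent[n] != n:
            parent[n] = parent[parent[n]]
            n = parent[n]
        return n

    for a, b in edges:
        parent[find(a)] = find(b)
    return len({find(n) for n in nodes})


# Toy example: two chains, 0-1-2 and 3-4
nodes = [0, 1, 2, 3, 4]
edges = [(0, 1), (1, 2), (3, 4)]
assert count_tracks(nodes, edges) == len(nodes) - len(edges)

# Applied to the solution graph reported above
print(26729 - 25316)  # → 1413
```

This is consistent with the 1413 items processed when the graph is converted to the CTC format in the next step.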

Converting the tracks#

Below we save the tracks to disk in the Cell Tracking Challenge (CTC) format, then load them again for additional filtering.

# Write to cell tracking challenge format
ctc_tracks, masks_tracked = graph_to_ctc(
    track_graph,
    masks,
    outdir="../../pvc/scratch/SHARE/cancer_cell_masks/series003_cCAR_tumor_tracked",
)
Converting graph to CTC results: 100%|██████████| 1413/1413 [00:00<00:00, 1726.61it/s]
Saving masks: 100%|██████████| 162/162 [00:00<00:00, 181.85it/s]

Track refinement#

from napari_ctc_io.reader import read_ctc, _ctc_to_napari_tracks
from pathlib import Path
import pandas as pd
import numpy as np
import napari
masks, tracks, tracks_graph = read_ctc(
    Path(r"C:\Users\Cyril\Desktop\Code\tumorscope\DATA\tracked")
)
INFO:napari_ctc_io.reader:Loaded tracks from C:\Users\Cyril\Desktop\Code\tumorscope\DATA\tracked\man_track.txt
INFO:napari_ctc_io.reader:Running CTC format checks
WARNING:napari_ctc_io.reader:1 non-connected masks at t=1.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=2.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=3.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=6.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=7.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=8.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=9.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=10.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=11.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=12.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=14.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=15.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=16.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=17.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=19.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=20.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=21.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=22.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=23.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=24.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=25.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=26.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=27.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=29.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=30.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=32.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=33.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=36.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=37.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=38.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=39.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=40.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=43.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=44.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=46.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=48.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=49.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=50.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=51.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=53.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=54.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=55.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=58.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=59.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=60.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=62.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=64.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=65.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=69.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=70.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=71.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=72.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=73.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=74.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=75.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=76.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=77.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=79.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=81.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=82.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=83.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=84.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=85.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=86.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=88.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=89.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=90.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=92.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=95.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=96.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=97.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=98.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=99.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=100.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=101.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=104.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=105.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=109.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=110.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=111.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=113.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=115.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=116.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=117.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=118.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=119.
WARNING:napari_ctc_io.reader:4 non-connected masks at t=120.
WARNING:napari_ctc_io.reader:4 non-connected masks at t=121.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=122.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=123.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=124.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=125.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=126.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=127.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=128.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=129.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=132.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=133.
WARNING:napari_ctc_io.reader:4 non-connected masks at t=134.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=136.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=137.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=141.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=142.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=145.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=146.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=147.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=149.
WARNING:napari_ctc_io.reader:5 non-connected masks at t=150.
WARNING:napari_ctc_io.reader:4 non-connected masks at t=151.
WARNING:napari_ctc_io.reader:5 non-connected masks at t=152.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=153.
WARNING:napari_ctc_io.reader:2 non-connected masks at t=154.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=155.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=157.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=158.
WARNING:napari_ctc_io.reader:1 non-connected masks at t=159.
WARNING:napari_ctc_io.reader:3 non-connected masks at t=161.
INFO:napari_ctc_io.reader:Checks completed
tracks_df = pd.DataFrame(tracks, columns=["label", "frame", "x", "y"])
tracks_df.label = tracks_df.label.astype(np.uint16)
tracks_df.frame = tracks_df.frame.astype(np.uint16)

Filter tracks by time#

Below we discard tracks that are shorter than a minimum number of frames, which helps remove false positives. The function defaults to 5 frames; in this example we use 20. The threshold can be adjusted depending on the specific data. Note that this function could be modified to filter by other criteria, such as the size of the masks.

def filter_masks_by_time(masks, tracks_df, min_length=5):
    """
    Remove tracks that appear in fewer than `min_length` frames.

    Returns the masks (T, H, W) with the labels of short tracks set to 0,
    and the rows of `tracks_df` that belong to the remaining tracks.
    """
    # Number of rows (one per frame) for each track label
    track_lengths = tracks_df.groupby("label").size()
    valid_labels = track_lengths[track_lengths >= min_length].index.values
    # In the masks array (T, H, W), zero out labels that are not in valid_labels
    masks_filtered = np.zeros_like(masks, dtype=np.uint16)
    for t in range(masks.shape[0]):
        masks_filtered[t] = np.where(np.isin(masks[t], valid_labels), masks[t], 0)
    tracks_df_filtered = tracks_df[tracks_df["label"].isin(valid_labels)]
    return masks_filtered, tracks_df_filtered

# Keep only tracks that are present in at least 20 frames
masks_filtered, filtered_tracks_df = filter_masks_by_time(masks, tracks_df, min_length=20)
filtered_tracks_df
       label  frame           x           y
0          1      0   26.634961   13.174807
3          4      0   14.884073  104.656250
4          5      0   16.186521  212.166099
5          6      0    7.246180  284.018676
6          7      0   16.390606  348.300783
...      ...    ...         ...         ...
26661   1049    161  358.852544  202.169263
26662   1050    161  493.740546  463.212185
26663   1055    161   36.035928  129.623752
26664   1062    161  633.679825  101.527778
26665   1065    161  229.864633  670.640133

23183 rows × 4 columns
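As an example of filtering by another criterion, the sketch below removes, in each frame, every label whose area is below a threshold. It is an illustration under assumptions: `filter_masks_by_area` is a hypothetical helper, the default `min_area` of 50 pixels is arbitrary, and the filter works per frame, so a track may lose only some of its frames.

```python
import numpy as np


def filter_masks_by_area(masks, min_area=50):
    """Zero out, frame by frame, every label covering fewer than `min_area` pixels."""
    masks_filtered = np.zeros_like(masks)
    for t in range(masks.shape[0]):
        labels, counts = np.unique(masks[t], return_counts=True)
        # Keep labels that are large enough; 0 is the background
        keep = labels[(counts >= min_area) & (labels != 0)]
        masks_filtered[t] = np.where(np.isin(masks[t], keep), masks[t], 0)
    return masks_filtered


# Toy example: label 1 covers 4 pixels, label 2 covers a single pixel
toy = np.zeros((1, 4, 4), dtype=np.uint16)
toy[0, :2, :2] = 1
toy[0, 3, 3] = 2
print(np.unique(filter_masks_by_area(toy, min_area=2)))  # → [0 1]
```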

Save the filtered tracks#

from tqdm import tqdm
import tifffile

outdir = Path("./tracked")
# Create the output folder if it does not exist yet
outdir.mkdir(parents=False, exist_ok=True)

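# Note: the rows written here are (label, frame, x, y), one per detection,
# which differs from the CTC man_track.txt layout (label, begin, end, parent)
filtered_tracks_df.to_csv(outdir / "man_track.txt", index=False, header=False, sep=" ")
# Save the filtered masks, one TIFF file per frame
for i, m in tqdm(enumerate(masks_filtered), total=len(masks_filtered), desc="Saving masks"):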
    tifffile.imwrite(
        outdir / f"man_track{i:04d}.tif",
        m,
        compression="zstd",
    )
Saving masks: 100%|██████████| 162/162 [00:03<00:00, 49.56it/s]
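In the CTC format, each line of man_track.txt describes one track as four integers: label, first frame, last frame, and parent label (0 if there is no parent). The sketch below shows one way to build such rows from per-frame (label, frame) pairs. It assumes that there are no divisions, so every parent is 0, and it does not handle tracks with gaps; `tracks_to_ctc_rows` is a hypothetical helper, not part of Trackastra or napari_ctc_io.

```python
def tracks_to_ctc_rows(label_frame_pairs):
    """Collapse per-frame (label, frame) pairs into CTC rows (label, begin, end, parent)."""
    spans = {}
    for label, frame in label_frame_pairs:
        begin, end = spans.get(label, (frame, frame))
        spans[label] = (min(begin, frame), max(end, frame))
    # Parent is 0 for every track, since divisions are assumed to be absent
    return [(label, begin, end, 0) for label, (begin, end) in sorted(spans.items())]


pairs = [(1, 0), (1, 1), (1, 2), (4, 0), (4, 5)]
for row in tracks_to_ctc_rows(pairs):
    print(*row)
# prints:
# 1 0 2 0
# 4 0 5 0
```

With the DataFrame from this notebook, the input pairs could be obtained with `zip(filtered_tracks_df.label, filtered_tracks_df.frame)`.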
# Optional: inspect the results in napari (uncomment to run)
# v = napari.Viewer()
# v.add_labels(masks, name="masks")
# v.add_tracks(tracks, name="tracks")
# v.add_labels(masks_filtered, name="masks_filtered")

# filtered_tracks = filtered_tracks_df.to_numpy().astype(np.float32)
# v.add_tracks(filtered_tracks, name="filtered_tracks")