Skip to content

Commit

Permalink
return statistics from python api; add totalseg_get_phase
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Mar 12, 2024
1 parent e8f6fb5 commit de8cc53
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 34 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## Master

* also return statistics from python api
* add `totalseg_get_phase`

## Release 2.1.0
* Bugfix: add flush to DummyFile
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
'totalseg_import_weights=totalsegmentator.bin.totalseg_import_weights:main',
'totalseg_download_weights=totalsegmentator.bin.totalseg_download_weights:main',
'totalseg_setup_manually=totalsegmentator.bin.totalseg_setup_manually:main',
'totalseg_set_license=totalsegmentator.bin.totalseg_set_license:main'
'totalseg_set_license=totalsegmentator.bin.totalseg_set_license:main',
'totalseg_get_phase=totalsegmentator.bin.totalseg_get_phase:main'
],
},
)
113 changes: 113 additions & 0 deletions totalsegmentator/bin/totalseg_get_phase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#!/usr/bin/env python
import sys
from pathlib import Path
import argparse
import json
import pickle
from pprint import pprint

import nibabel as nib
import numpy as np

from totalsegmentator.python_api import totalsegmentator
from totalsegmentator.statistics import get_basic_statistics


def pi_time_to_phase(pi_time: float) -> str:
"""
Convert the pi time to a phase and get a probability for the value.
native: 0-10
arterial_early: 10-30
arterial_late: 30-50
portal_venous: 50-100
delayed: 100+
returns: phase, probability
"""
if pi_time < 5:
return "native", 1.0
elif pi_time < 10:
return "native", 0.7
elif pi_time < 20:
return "arterial_early", 0.7
elif pi_time < 30:
return "arterial_early", 1.0
elif pi_time < 50:
return "arterial_late", 1.0
elif pi_time < 60:
return "portal_venous", 0.7
elif pi_time < 90:
return "portal_venous", 1.0
elif pi_time < 100:
return "portal_venous", 0.7
else:
return "delayed", 0.7


def get_ct_contrast_phase(ct_img: nib.Nifti1Image):

organs = ["liver", "spleen", "kidney_left", "kidney_right", "pancreas", "urinary_bladder", "gallbladder",
"heart", "aorta", "inferior_vena_cava", "portal_vein_and_splenic_vein",
"iliac_vena_left", "iliac_vena_right", "iliac_artery_left", "iliac_artery_right",
"pulmonary_vein"]

seg_img, stats = totalsegmentator(ct_img, None, ml=True, fast=True, statistics=True,
roi_subset=None, quiet=False)

features = []
for organ in organs:
features.append(stats[organ]["intensity"])

# todo: adapt
# classifier_path = Path(__file__).parent / "classifier.pkl"
classifier_path = "/mnt/nvme/data/phase_classification/classifiers.pkl"
clfs = pickle.load(open(classifier_path, "rb"))

# ensemble across folds
preds = []
for fold, clf in clfs.items():
preds.append(clf.predict([features])[0])
preds = np.array(preds)
pi_time = round(float(np.mean(preds)), 2)
pi_time_std = round(float(np.std(preds)), 4)

print("Ensemble res:")
print(preds)
# print(f"mean: {pi_time} +/- {pi_time_std}")
print(f"mean: {pi_time} [{preds.min():.1f}-{preds.max():.1f}]")
phase, probability = pi_time_to_phase(pi_time)

return {"pi_time": pi_time, "phase": phase, "probability": probability}


def main():
"""
The the contrast phase of a CT scan. Specifically this script will predict the
pi (post injection) time (in seconds) of a CT scan based on the intensity of different regions
in the image
"""
parser = argparse.ArgumentParser(description="Get CT contrast phase.",
epilog="Written by Jakob Wasserthal. If you use this tool please cite https://pubs.rsna.org/doi/10.1148/ryai.230024")

parser.add_argument("-i", metavar="filepath", dest="input_file",
help="path to CT file",
type=lambda p: Path(p).absolute(), required=True)

parser.add_argument("-o", metavar="filepath", dest="output_file",
help="path to output json file",
type=lambda p: Path(p).absolute(), required=True)

args = parser.parse_args()

res = get_ct_contrast_phase(nib.load(args.input_file))

print("Result:")
pprint(res)

with open(args.output_file, "w") as f:
f.write(json.dumps(res, indent=4))


if __name__ == "__main__":
main()
19 changes: 12 additions & 7 deletions totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,13 +514,18 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
# Speed:
# stats on 1.5mm: 37s
# stats on 3.0mm: 4s -> great improvement
stats = None
if statistics:
if not quiet: print("Calculating statistics fast...")
st = time.time()
stats_dir = file_out.parent if multilabel_image else file_out
stats_dir.mkdir(exist_ok=True)
get_basic_statistics(img_pred.get_fdata(), img_in_rsp, stats_dir / "statistics.json", quiet, task_name,
exclude_masks_at_border)
if file_out is not None:
stats_dir = file_out.parent if multilabel_image else file_out
stats_dir.mkdir(exist_ok=True)
stats_file = stats_dir / "statistics.json"
else:
stats_file = None
stats = get_basic_statistics(img_pred.get_fdata(), img_in_rsp, stats_file,
quiet, task_name, exclude_masks_at_border, roi_subset)
if not quiet: print(f" calculated in {time.time()-st:.2f}s")

if resample is not None:
Expand All @@ -529,8 +534,8 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
# Use force_affine otherwise output affine sometimes slightly off (which then is even increased
# by undo_canonical)
img_pred = change_spacing(img_pred, [resample, resample, resample], img_in_shape,
order=0, dtype=np.uint8, nr_cpus=nr_threads_resampling,
force_affine=img_in.affine)
order=0, dtype=np.uint8, nr_cpus=nr_threads_resampling,
force_affine=img_in.affine)

if verbose: print("Undoing canonical...")
img_pred = undo_canonical(img_pred, img_in_orig)
Expand Down Expand Up @@ -644,5 +649,5 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
skin = extract_skin(img_in_orig, nib.load(file_out / "body.nii.gz"))
nib.save(skin, file_out / "skin.nii.gz")

return img_out, img_in_orig
return img_out, img_in_orig, stats

44 changes: 27 additions & 17 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def show_license_info():
sys.exit(1)


def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Path, None], ml=False, nr_thr_resamp=1, nr_thr_saving=6,
def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Path, None]=None, ml=False, nr_thr_resamp=1, nr_thr_saving=6,
fast=False, nora_tag="None", preview=False, task="total", roi_subset=None,
statistics=False, radiomics=False, crop_path=None, body_seg=False,
force_split=False, output_type="nifti", quiet=False, verbose=False, test=0,
Expand All @@ -62,8 +62,8 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
if output is not None:
output = Path(output)
else:
if statistics or radiomics:
raise ValueError("Output path is required for statistics and radiomics.")
if radiomics:
raise ValueError("Output path is required for radiomics.")

nora_tag = "None" if nora_tag is None else nora_tag

Expand Down Expand Up @@ -278,6 +278,7 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
raise ValueError("roi_subset only works with task 'total'")

# Generate rough organ segmentation (6mm) for speed up if crop or roi_subset is used
# (for "fast" on GPU it makes no big difference, but on CPU it can help even for "fast")
if crop is not None or roi_subset is not None:

body_seg = False # can not be used together with body_seg
Expand All @@ -290,7 +291,7 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
else:
crop_model_task = 298
crop_spacing = 6.0
organ_seg, _ = nnUNet_predict_image(input, None, crop_model_task, model="3d_fullres", folds=[0],
organ_seg, _, _ = nnUNet_predict_image(input, None, crop_model_task, model="3d_fullres", folds=[0],
trainer="nnUNetTrainer_4000epochs_NoMirroring", tta=False, multilabel_image=True, resample=crop_spacing,
crop=None, crop_path=None, task_name="total", nora_tag="None", preview=False,
save_binary=False, nr_threads_resampling=nr_thr_resamp, nr_threads_saving=1,
Expand All @@ -313,7 +314,7 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
download_pretrained_weights(300)
st = time.time()
if not quiet: print("Generating rough body segmentation...")
body_seg, _ = nnUNet_predict_image(input, None, 300, model="3d_fullres", folds=[0],
body_seg, _, _ = nnUNet_predict_image(input, None, 300, model="3d_fullres", folds=[0],
trainer="nnUNetTrainer", tta=False, multilabel_image=True, resample=6.0,
crop=None, crop_path=None, task_name="body", nora_tag="None", preview=False,
save_binary=True, nr_threads_resampling=nr_thr_resamp, nr_threads_saving=1,
Expand All @@ -323,15 +324,15 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
if verbose: print(f"Rough body segmentation generated in {time.time()-st:.2f}s")

folds = [0] # None
seg_img, ct_img = nnUNet_predict_image(input, output, task_id, model=model, folds=folds,
trainer=trainer, tta=False, multilabel_image=ml, resample=resample,
crop=crop, crop_path=crop_path, task_name=task, nora_tag=nora_tag, preview=preview,
nr_threads_resampling=nr_thr_resamp, nr_threads_saving=nr_thr_saving,
force_split=force_split, crop_addon=crop_addon, roi_subset=roi_subset,
output_type=output_type, statistics=statistics_fast,
quiet=quiet, verbose=verbose, test=test, skip_saving=skip_saving, device=device,
exclude_masks_at_border=statistics_exclude_masks_at_border,
no_derived_masks=no_derived_masks, v1_order=v1_order)
seg_img, ct_img, stats = nnUNet_predict_image(input, output, task_id, model=model, folds=folds,
trainer=trainer, tta=False, multilabel_image=ml, resample=resample,
crop=crop, crop_path=crop_path, task_name=task, nora_tag=nora_tag, preview=preview,
nr_threads_resampling=nr_thr_resamp, nr_threads_saving=nr_thr_saving,
force_split=force_split, crop_addon=crop_addon, roi_subset=roi_subset,
output_type=output_type, statistics=statistics_fast,
quiet=quiet, verbose=verbose, test=test, skip_saving=skip_saving, device=device,
exclude_masks_at_border=statistics_exclude_masks_at_border,
no_derived_masks=no_derived_masks, v1_order=v1_order)
seg = seg_img.get_fdata().astype(np.uint8)

try:
Expand All @@ -348,8 +349,14 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
if statistics:
if not quiet: print("Calculating statistics...")
st = time.time()
stats_dir = output.parent if ml else output
get_basic_statistics(seg, ct_img, stats_dir / "statistics.json", quiet, task, statistics_exclude_masks_at_border)
if output is not None:
stats_dir = output.parent if ml else output
stats_file = stats_dir / "statistics.json"
else:
stats_file = None
stats = get_basic_statistics(seg, ct_img, stats_file,
quiet, task, statistics_exclude_masks_at_border,
roi_subset)
# get_radiomics_features_for_entire_dir(input, output, output / "statistics_radiomics.json")
if not quiet: print(f" calculated in {time.time()-st:.2f}s")

Expand All @@ -370,4 +377,7 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
get_radiomics_features_for_entire_dir(input_path, output, stats_dir / "statistics_radiomics.json")
if not quiet: print(f" calculated in {time.time()-st:.2f}s")

return seg_img
if statistics or statistics_fast:
return seg_img, stats
else:
return seg_img
30 changes: 22 additions & 8 deletions totalsegmentator/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import json
from functools import partial
import time
from typing import Union

import numpy as np
import pandas as pd
import nibabel as nib
from nibabel.nifti1 import Nifti1Image
from tqdm import tqdm
from p_tqdm import p_map
import numpy.ma as ma
Expand Down Expand Up @@ -90,17 +92,27 @@ def touches_border(mask):
return False


def get_basic_statistics(seg: np.array, ct_file, file_out: Path, quiet: bool = False,
task: str = "total", exclude_masks_at_border: bool = True):
def get_basic_statistics(seg: np.array,
ct_file: Union[Path, Nifti1Image],
file_out: Union[Path, None]=None,
quiet: bool=False,
task: str="total",
exclude_masks_at_border: bool=True,
roi_subset: list=None):
"""
ct_file: path to a ct_file or a nifti file object
"""
ct_img = nib.load(ct_file) if type(ct_file) == pathlib.PosixPath else ct_file
ct = ct_img.get_fdata().astype(np.int16)
spacing = ct_img.header.get_zooms()
vox_vol = spacing[0] * spacing[1] * spacing[2]

class_map_stats = class_map[task]
if roi_subset is not None:
class_map_stats = {k: v for k, v in class_map_stats.items() if v in roi_subset}

stats = {}
for k, mask_name in tqdm(class_map[task].items(), disable=quiet):
for k, mask_name in tqdm(class_map_stats.items(), disable=quiet):
stats[mask_name] = {}
# data = nib.load(mask).get_fdata() # loading: 0.6s
data = seg == k # 0.18s
Expand All @@ -114,8 +126,10 @@ def get_basic_statistics(seg: np.array, ct_file, file_out: Path, quiet: bool = F
# stats[mask_name]["intensity"] = ct[roi_mask > 0].mean().round(2) if roi_mask.sum() > 0 else 0.0 # 3.0s
stats[mask_name]["intensity"] = np.average(ct, weights=roi_mask).round(2) if roi_mask.sum() > 0 else 0.0 # 0.9s

# For nora json is good
# For other people csv might be better -> not really because here only for one subject each -> use json
with open(file_out, "w") as f:
json.dump(stats, f, indent=4)

if file_out is not None:
# For nora json is good
# For other people csv might be better -> not really because here only for one subject each -> use json
with open(file_out, "w") as f:
json.dump(stats, f, indent=4)
else:
return stats

0 comments on commit de8cc53

Please sign in to comment.