Skip to content

Commit

Permalink
allow different resampling for all axis
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Jun 25, 2024
1 parent 6753496 commit c6bb6db
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
11 changes: 7 additions & 4 deletions totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
v1_order=False):
"""
crop: string or a nibabel image
resample: None or float (target spacing for all dimensions)
resample: None or float (target spacing for all dimensions) or list of floats
"""
if not isinstance(file_in, Nifti1Image):
file_in = Path(file_in)
Expand All @@ -320,6 +320,9 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
class_map_parts = class_map_parts_mr
map_taskid_to_partname = map_taskid_to_partname_mr

if type(resample) is float:
resample = [resample, resample, resample]

# for debugging
# tmp_dir = file_in.parent / ("nnunet_tmp_" + ''.join(random.Random().choices(string.ascii_uppercase + string.digits, k=8)))
# (tmp_dir).mkdir(exist_ok=True)
Expand Down Expand Up @@ -380,7 +383,7 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
st = time.time()
img_in_shape = img_in.shape
img_in_zooms = img_in.header.get_zooms()
img_in_rsp = change_spacing(img_in, [resample, resample, resample],
img_in_rsp = change_spacing(img_in, resample,
order=3, dtype=np.int32, nr_cpus=nr_threads_resampling) # 4 cpus instead of 1 makes it a bit slower
if verbose:
print(f" from shape {img_in.shape} to shape {img_in_rsp.shape}")
Expand Down Expand Up @@ -413,7 +416,7 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
nib.save(nib.Nifti1Image(img_in_rsp_data[:, :, third*2+1-margin:], img_in_rsp.affine),
tmp_dir / "s03_0000.nii.gz")

if task_name == "total" and resample < 3.0:
if task_name == "total" and resample is not None and resample[0] < 3.0:
# overall speedup for 15mm model roughly 11% (GPU) and 100% (CPU)
# overall speedup for 3mm model roughly 0% (GPU) and 10% (CPU)
# (dice 0.001 worse on test set -> ok)
Expand Down Expand Up @@ -549,7 +552,7 @@ def nnUNet_predict_image(file_in: Union[str, Path, Nifti1Image], file_out, task_
if verbose: print(f" back to original shape: {img_in_shape}")
# Use force_affine otherwise output affine sometimes slightly off (which then is even increased
# by undo_canonical)
img_pred = change_spacing(img_pred, [resample, resample, resample], img_in_shape,
img_pred = change_spacing(img_pred, resample, img_in_shape,
order=0, dtype=np.uint8, nr_cpus=nr_threads_resampling,
force_affine=img_in.affine)

Expand Down
6 changes: 4 additions & 2 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,10 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
folds = None
if fast: raise ValueError("task head_muscles does not work with option --fast")
elif task == "headneck_muscles":
task_id = 778
resample = None
# task_id = 778
# resample = None
task_id = [778, 779]
resample = [1., 0.75, 0.75]
trainer = "nnUNetTrainer_DASegOrd0_NoMirroring"
# crop = ["skull", "clavicula_left", "clavicula_right", "vertebrae_C5", "vertebrae_T1", "vertebrae_T4"]
# crop_addon = [10, 10, 10]
Expand Down

0 comments on commit c6bb6db

Please sign in to comment.