Skip to content

Commit

Permalink
change step_size to 0.8 for faster runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
wasserth committed Feb 15, 2024
1 parent 56a7a68 commit d657c52
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
* input to python api can be a Nifti1Image object or a file path
* upgrade to `nnunetv2>=2.2.1`
* minor edits and bugfixes
* use nnU-Net `step_size=0.8` instead of `0.5` for faster runtime while only decreasing dice by 0.001


## Release 2.0.5
Expand Down
8 changes: 5 additions & 3 deletions tests/test_locally.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def are_logs_similar(last_log, new_log, cols, tolerance_percent=0.04):
cpu_utilization = {}
gpu_utilization = {}

device = "gpu" # "cpu" or "gpu"

for resolution in ["15mm", "3mm"]:
# for resolution in ["3mm"]:
img_dir = base_dir / resolution / "ct"
Expand All @@ -193,7 +195,7 @@ def are_logs_similar(last_log, new_log, cols, tolerance_percent=0.04):
for img_fn in tqdm(img_dir.glob("*.nii.gz")):
fast = resolution == "3mm"
st = time.time()
totalsegmentator(img_fn, pred_dir / img_fn.name, fast=fast, ml=True, device="gpu")
totalsegmentator(img_fn, pred_dir / img_fn.name, fast=fast, ml=True, device=device)
times[resolution].append(time.time()-st)

print("Logging...")
Expand Down Expand Up @@ -229,7 +231,7 @@ def are_logs_similar(last_log, new_log, cols, tolerance_percent=0.04):
"gpu_utilization_15mm", "gpu_utilization_3mm",
"python_version", "torch_version", "nnunet_version",
"cuda_version", "cudnn_version",
"gpu_name"]
"gpu_name", "comment"]
overview_file = Path(f"{base_dir}/overview.xlsx")
if overview_file.exists():
overview = pd.read_excel(overview_file)
Expand All @@ -248,7 +250,7 @@ def are_logs_similar(last_log, new_log, cols, tolerance_percent=0.04):
platform.python_version(), torch.__version__,
importlib.metadata.version("nnunetv2"),
float(torch.version.cuda), int(torch.backends.cudnn.version()),
torch.cuda.get_device_name(0)]
torch.cuda.get_device_name(0), ""]

print("Comparing NEW to PREVIOUS log:")
if are_logs_similar(last_log, new_log, cols):
Expand Down
7 changes: 5 additions & 2 deletions totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,11 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None,
device = torch.device('cuda')
else:
device = torch.device('mps')
step_size = 0.5
# step_size = 0.8 # overall speedup roughly 11%; for fast model no speedup; dice 0.001 worse
# step_size = 0.5
# overall speedup for 15mm model roughly 11% (GPU) and 100% (CPU)
# overall speedup for 3mm model roughly 0% (GPU) and 10% (CPU)
# (dice 0.001 worse on test set -> ok)
step_size = 0.8
disable_tta = not tta
verbose = False
save_probabilities = False
Expand Down

0 comments on commit d657c52

Please sign in to comment.