From 1de45113756dad7c8d2d489fbc98a14d2e4c9f9a Mon Sep 17 00:00:00 2001 From: jakob Date: Wed, 27 Nov 2024 09:17:36 +0100 Subject: [PATCH] make sure torch cudnn benchmark and num_threads setting is not permanently changed by totalsegmentator --- totalsegmentator/python_api.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/totalsegmentator/python_api.py b/totalsegmentator/python_api.py index a7f4d5029..cf3acec3e 100644 --- a/totalsegmentator/python_api.py +++ b/totalsegmentator/python_api.py @@ -91,6 +91,10 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa nora_tag = "None" if nora_tag is None else nora_tag + # Store initial torch settings + initial_cudnn_benchmark = torch.backends.cudnn.benchmark + initial_num_threads = torch.get_num_threads() + validate_device_type_api(device) device = convert_device_to_cuda(device) @@ -555,6 +559,10 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa get_radiomics_features_for_entire_dir(input_path, output, stats_dir / "statistics_radiomics.json") if not quiet: print(f" calculated in {time.time()-st:.2f}s") + # Restore initial torch settings + torch.backends.cudnn.benchmark = initial_cudnn_benchmark + torch.set_num_threads(initial_num_threads) + if statistics or statistics_fast: return seg_img, stats else: