From 24232a9b3500b9ec6e41891400055d3c1b1a285e Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Tue, 3 Dec 2024 03:39:56 +0800 Subject: [PATCH] Perf: remove redundant checks on data integrity (#4433) Systems are aggregated here: <https://github.com/deepmodeling/deepmd-kit/blob/f343a3b212edab5525502e0261f3068c0b6fb1f6/deepmd/utils/data_system.py#L802> and later initialized here: <https://github.com/deepmodeling/deepmd-kit/blob/f343a3b212edab5525502e0261f3068c0b6fb1f6/deepmd/utils/data_system.py#L809-L810>. This process instantiates the `DeepmdData` class, which already performs data integrity checks: <https://github.com/deepmodeling/deepmd-kit/blob/e695a91ca6f7a1c9c830ab1c58b7b7a05db3da23/deepmd/utils/data.py#L80-L82>. Besides, the checking process enumerates all items on every rank, which is unnecessary and quite slow, so this PR removes the redundant check. ## Summary by CodeRabbit - **New Features** - Enhanced flexibility in defining test sizes by allowing percentage input for the `test_size` parameter. - Introduced a new method to automatically compute test sizes based on the specified percentage of total data. - Improved path handling to accept both string and Path inputs, enhancing usability. - **Bug Fixes** - Improved error handling for invalid paths, ensuring users receive clear feedback when files are not found. - **Deprecation Notice** - The `get_test` method is now deprecated, with new logic implemented for loading test data when necessary. 
--------- Signed-off-by: Chun Cai Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> (cherry picked from commit 3917cf0be777f548b9bc4da63af97a78e8e9e952) --- deepmd/utils/data.py | 2 ++ deepmd/utils/data_system.py | 22 ++-------------------- deepmd/utils/path.py | 12 ++++++------ 3 files changed, 10 insertions(+), 26 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 493a9d8d54..39af73cab3 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -60,6 +60,8 @@ def __init__( ) -> None: """Constructor.""" root = DPPath(sys_path) + if not root.is_dir(): + raise FileNotFoundError(f"System {sys_path} is not found!") self.dirs = root.glob(set_prefix + ".*") if not len(self.dirs): raise FileNotFoundError(f"No {set_prefix}.* is found in {sys_path}") diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 0e960d0ba1..445461b387 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -28,9 +28,6 @@ from deepmd.utils.out_stat import ( compute_stats_from_redu, ) -from deepmd.utils.path import ( - DPPath, -) log = logging.getLogger(__name__) @@ -103,6 +100,8 @@ def __init__( del rcut self.system_dirs = systems self.nsystems = len(self.system_dirs) + if self.nsystems <= 0: + raise ValueError("No systems provided") self.data_systems = [] for ii in self.system_dirs: self.data_systems.append( @@ -755,23 +754,6 @@ def process_systems(systems: Union[str, list[str]]) -> list[str]: systems = expand_sys_str(systems) elif isinstance(systems, list): systems = systems.copy() - help_msg = "Please check your setting for data systems" - # check length of systems - if len(systems) == 0: - msg = "cannot find valid a data system" - log.fatal(msg) - raise OSError(msg, help_msg) - # roughly check all items in systems are valid - for ii in systems: - ii = DPPath(ii) - if 
not ii.is_dir(): - msg = f"dir {ii} is not a valid dir" - log.fatal(msg) - raise OSError(msg, help_msg) - if not (ii / "type.raw").is_file(): - msg = f"dir {ii} is not a valid data system dir" - log.fatal(msg) - raise OSError(msg, help_msg) return systems diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index 6c52caac1d..c542ccf661 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -14,6 +14,7 @@ from typing import ( ClassVar, Optional, + Union, ) import h5py @@ -157,19 +158,16 @@ class DPOSPath(DPPath): Parameters ---------- - path : str + path : Union[str, Path] path mode : str, optional mode, by default "r" """ - def __init__(self, path: str, mode: str = "r") -> None: + def __init__(self, path: Union[str, Path], mode: str = "r") -> None: super().__init__() self.mode = mode - if isinstance(path, Path): - self.path = path - else: - self.path = Path(path) + self.path = Path(path) def load_numpy(self) -> np.ndarray: """Load NumPy array. @@ -300,6 +298,8 @@ def __init__(self, path: str, mode: str = "r") -> None: # so we do not support file names containing #... s = path.split("#") self.root_path = s[0] + if not os.path.isfile(self.root_path): + raise FileNotFoundError(f"{self.root_path} not found") self.root = self._load_h5py(s[0], mode) # h5 path: default is the root path self._name = s[1] if len(s) > 1 else "/"