From a3672f216aa162f2549d1712fad0118b2cc98d49 Mon Sep 17 00:00:00 2001 From: Valerij Talagayev <82884038+talagayev@users.noreply.github.com> Date: Tue, 17 Dec 2024 02:23:19 +0100 Subject: [PATCH] Implementation of Parallelization to `MDAnalysis.analysis.contacts` (#4820) * Fixes #4660 * summary of changes: - added backends and aggregators to Contacts in analysis.contacts - added private _get_box_func method because lambdas cannot be used for parallelization - added the client_Contacts in conftest.py - added client_Contacts in run() in test_contacts.py * Update CHANGELOG --- package/CHANGELOG | 2 + package/MDAnalysis/analysis/contacts.py | 48 ++++++++-- .../MDAnalysisTests/analysis/conftest.py | 8 ++ .../MDAnalysisTests/analysis/test_contacts.py | 87 ++++++++++++------- 4 files changed, 105 insertions(+), 40 deletions(-) diff --git a/package/CHANGELOG b/package/CHANGELOG index 5a1510d9ce2..03255eb4ad3 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -25,10 +25,12 @@ Fixes the function to prevent shared state. 
(Issue #4655) Enhancements + * Enable parallelization for analysis.contacts.Contacts (Issue #4660) * Enable parallelization for analysis.nucleicacids.NucPairDist (Issue #4670) * Add check and warning for empty (all zero) coordinates in RDKit converter (PR #4824) * Added `precision` for XYZWriter (Issue #4775, PR #4771) + Changes Deprecations diff --git a/package/MDAnalysis/analysis/contacts.py b/package/MDAnalysis/analysis/contacts.py index 7a7e195f09a..f29fd4961e8 100644 --- a/package/MDAnalysis/analysis/contacts.py +++ b/package/MDAnalysis/analysis/contacts.py @@ -223,7 +223,7 @@ def is_any_closer(r, r0, dist=2.5): from MDAnalysis.lib.util import openany from MDAnalysis.analysis.distances import distance_array from MDAnalysis.core.groups import AtomGroup, UpdatingAtomGroup -from .base import AnalysisBase +from .base import AnalysisBase, ResultsGroup logger = logging.getLogger("MDAnalysis.analysis.contacts") @@ -376,8 +376,22 @@ class Contacts(AnalysisBase): :class:`MDAnalysis.analysis.base.Results` instance. .. versionchanged:: 2.2.0 :class:`Contacts` accepts both AtomGroup and string for `select` + .. versionchanged:: 2.9.0 + Introduced :meth:`get_supported_backends` allowing + for parallel execution on :mod:`multiprocessing` + and :mod:`dask` backends. 
""" + _analysis_algorithm_is_parallelizable = True + + @classmethod + def get_supported_backends(cls): + return ( + "serial", + "multiprocessing", + "dask", + ) + def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5, pbc=True, kwargs=None, **basekwargs): """ @@ -444,11 +458,8 @@ def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5, self.r0 = [] self.initial_contacts = [] - #get dimension of box if pbc set to True - if self.pbc: - self._get_box = lambda ts: ts.dimensions - else: - self._get_box = lambda ts: None + # get dimensions via partial for parallelization compatibility + self._get_box = functools.partial(self._get_box_func, pbc=self.pbc) if isinstance(refgroup[0], AtomGroup): refA, refB = refgroup @@ -464,7 +475,6 @@ def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5, self.n_initial_contacts = self.initial_contacts[0].sum() - @staticmethod def _get_atomgroup(u, sel): select_error_message = ("selection must be either string or a " @@ -480,6 +490,28 @@ def _get_atomgroup(u, sel): else: raise TypeError(select_error_message) + @staticmethod + def _get_box_func(ts, pbc): + """Retrieve the dimensions of the simulation box based on PBC. + + Parameters + ---------- + ts : Timestep + The current timestep of the simulation, which contains the + box dimensions. + pbc : bool + A flag indicating whether periodic boundary conditions (PBC) + are enabled. If `True`, the box dimensions are returned, + else returns `None`. + + Returns + ------- + box_dimensions : ndarray or None + The dimensions of the simulation box as a NumPy array if PBC + is True, else returns `None`. 
+ """ + return ts.dimensions if pbc else None + def _prepare(self): self.results.timeseries = np.empty((self.n_frames, len(self.r0)+1)) @@ -506,6 +538,8 @@ def timeseries(self): warnings.warn(wmsg, DeprecationWarning) return self.results.timeseries + def _get_aggregator(self): + return ResultsGroup(lookup={'timeseries': ResultsGroup.ndarray_vstack}) def _new_selections(u_orig, selections, frame): """create stand alone AGs from selections at frame""" diff --git a/testsuite/MDAnalysisTests/analysis/conftest.py b/testsuite/MDAnalysisTests/analysis/conftest.py index a60b565f1c6..91e9bd760b8 100644 --- a/testsuite/MDAnalysisTests/analysis/conftest.py +++ b/testsuite/MDAnalysisTests/analysis/conftest.py @@ -15,6 +15,7 @@ HydrogenBondAnalysis, ) from MDAnalysis.analysis.nucleicacids import NucPairDist +from MDAnalysis.analysis.contacts import Contacts from MDAnalysis.lib.util import is_installed @@ -149,3 +150,10 @@ def client_HydrogenBondAnalysis(request): @pytest.fixture(scope="module", params=params_for_cls(NucPairDist)) def client_NucPairDist(request): return request.param + + +# MDAnalysis.analysis.contacts + +@pytest.fixture(scope="module", params=params_for_cls(Contacts)) +def client_Contacts(request): + return request.param diff --git a/testsuite/MDAnalysisTests/analysis/test_contacts.py b/testsuite/MDAnalysisTests/analysis/test_contacts.py index 85546cbc3f5..6b416e27f8e 100644 --- a/testsuite/MDAnalysisTests/analysis/test_contacts.py +++ b/testsuite/MDAnalysisTests/analysis/test_contacts.py @@ -171,8 +171,8 @@ def universe(): return mda.Universe(PSF, DCD) def _run_Contacts( - self, universe, - start=None, stop=None, step=None, **kwargs + self, universe, client_Contacts, start=None, + stop=None, step=None, **kwargs ): acidic = universe.select_atoms(self.sel_acidic) basic = universe.select_atoms(self.sel_basic) @@ -181,7 +181,8 @@ def _run_Contacts( select=(self.sel_acidic, self.sel_basic), refgroup=(acidic, basic), radius=6.0, - **kwargs).run(start=start, 
stop=stop, step=step) + **kwargs + ).run(**client_Contacts, start=start, stop=stop, step=step) @pytest.mark.parametrize("seltxt", [sel_acidic, sel_basic]) def test_select_valid_types(self, universe, seltxt): @@ -195,7 +196,7 @@ def test_select_valid_types(self, universe, seltxt): assert ag_from_string == ag_from_ag - def test_contacts_selections(self, universe): + def test_contacts_selections(self, universe, client_Contacts): """Test if Contacts can take both string and AtomGroup as selections. """ aga = universe.select_atoms(self.sel_acidic) @@ -210,8 +211,8 @@ def test_contacts_selections(self, universe): refgroup=(aga, agb) ) - cag.run() - csel.run() + cag.run(**client_Contacts) + csel.run(**client_Contacts) assert cag.grA == csel.grA assert cag.grB == csel.grB @@ -228,26 +229,31 @@ def test_select_wrong_types(self, universe, ag): ) as te: contacts.Contacts._get_atomgroup(universe, ag) - def test_startframe(self, universe): + def test_startframe(self, universe, client_Contacts): """test_startframe: TestContactAnalysis1: start frame set to 0 (resolution of Issue #624) """ - CA1 = self._run_Contacts(universe) + CA1 = self._run_Contacts(universe, client_Contacts=client_Contacts) assert len(CA1.results.timeseries) == universe.trajectory.n_frames - def test_end_zero(self, universe): + def test_end_zero(self, universe, client_Contacts): """test_end_zero: TestContactAnalysis1: stop frame 0 is not ignored""" - CA1 = self._run_Contacts(universe, stop=0) + CA1 = self._run_Contacts( + universe, client_Contacts=client_Contacts, stop=0 + ) assert len(CA1.results.timeseries) == 0 - def test_slicing(self, universe): + def test_slicing(self, universe, client_Contacts): start, stop, step = 10, 30, 5 - CA1 = self._run_Contacts(universe, start=start, stop=stop, step=step) + CA1 = self._run_Contacts( + universe, client_Contacts=client_Contacts, + start=start, stop=stop, step=step + ) frames = np.arange(universe.trajectory.n_frames)[start:stop:step] assert 
len(CA1.results.timeseries) == len(frames) - def test_villin_folded(self): + def test_villin_folded(self, client_Contacts): # one folded, one unfolded f = mda.Universe(contacts_villin_folded) u = mda.Universe(contacts_villin_unfolded) @@ -259,12 +265,12 @@ def test_villin_folded(self): select=(sel, sel), refgroup=(grF, grF), method="soft_cut") - q.run() + q.run(**client_Contacts) results = soft_cut(f, u, sel, sel) assert_allclose(q.results.timeseries[:, 1], results[:, 1], rtol=0, atol=1.5e-7) - def test_villin_unfolded(self): + def test_villin_unfolded(self, client_Contacts): # both folded f = mda.Universe(contacts_villin_folded) u = mda.Universe(contacts_villin_folded) @@ -276,13 +282,13 @@ def test_villin_unfolded(self): select=(sel, sel), refgroup=(grF, grF), method="soft_cut") - q.run() + q.run(**client_Contacts) results = soft_cut(f, u, sel, sel) assert_allclose(q.results.timeseries[:, 1], results[:, 1], rtol=0, atol=1.5e-7) - def test_hard_cut_method(self, universe): - ca = self._run_Contacts(universe) + def test_hard_cut_method(self, universe, client_Contacts): + ca = self._run_Contacts(universe, client_Contacts=client_Contacts) expected = [1., 0.58252427, 0.52427184, 0.55339806, 0.54368932, 0.54368932, 0.51456311, 0.46601942, 0.48543689, 0.52427184, 0.46601942, 0.58252427, 0.51456311, 0.48543689, 0.48543689, @@ -306,7 +312,7 @@ def test_hard_cut_method(self, universe): assert len(ca.results.timeseries) == len(expected) assert_allclose(ca.results.timeseries[:, 1], expected, rtol=0, atol=1.5e-7) - def test_radius_cut_method(self, universe): + def test_radius_cut_method(self, universe, client_Contacts): acidic = universe.select_atoms(self.sel_acidic) basic = universe.select_atoms(self.sel_basic) r = contacts.distance_array(acidic.positions, basic.positions) @@ -316,15 +322,20 @@ def test_radius_cut_method(self, universe): r = contacts.distance_array(acidic.positions, basic.positions) expected.append(contacts.radius_cut_q(r[initial_contacts], None, radius=6.0)) 
- ca = self._run_Contacts(universe, method='radius_cut') + ca = self._run_Contacts( + universe, client_Contacts=client_Contacts, method="radius_cut" + ) assert_array_equal(ca.results.timeseries[:, 1], expected) @staticmethod def _is_any_closer(r, r0, dist=2.5): return np.any(r < dist) - def test_own_method(self, universe): - ca = self._run_Contacts(universe, method=self._is_any_closer) + def test_own_method(self, universe, client_Contacts): + ca = self._run_Contacts( + universe, client_Contacts=client_Contacts, + method=self._is_any_closer + ) bound_expected = [1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1., @@ -340,13 +351,20 @@ def test_own_method(self, universe): def _weird_own_method(r, r0): return 'aaa' - def test_own_method_no_array_cast(self, universe): + def test_own_method_no_array_cast(self, universe, client_Contacts): with pytest.raises(ValueError): - self._run_Contacts(universe, method=self._weird_own_method, stop=2) - - def test_non_callable_method(self, universe): + self._run_Contacts( + universe, + client_Contacts=client_Contacts, + method=self._weird_own_method, + stop=2, + ) + + def test_non_callable_method(self, universe, client_Contacts): with pytest.raises(ValueError): - self._run_Contacts(universe, method=2, stop=2) + self._run_Contacts( + universe, client_Contacts=client_Contacts, method=2, stop=2 + ) @pytest.mark.parametrize("pbc,expected", [ (True, [1., 0.43138152, 0.3989021, 0.43824337, 0.41948765, @@ -354,7 +372,7 @@ def test_non_callable_method(self, universe): (False, [1., 0.42327791, 0.39192399, 0.40950119, 0.40902613, 0.42470309, 0.41140143, 0.42897862, 0.41472684, 0.38574822]) ]) - def test_distance_box(self, pbc, expected): + def test_distance_box(self, pbc, expected, client_Contacts): u = mda.Universe(TPR, XTC) sel_basic = "(resname ARG LYS)" sel_acidic = "(resname ASP GLU)" @@ -363,13 +381,15 @@ def test_distance_box(self, pbc, expected): r = contacts.Contacts(u, 
select=(sel_acidic, sel_basic), refgroup=(acidic, basic), radius=6.0, pbc=pbc) - r.run() + r.run(**client_Contacts) assert_allclose(r.results.timeseries[:, 1], expected,rtol=0, atol=1.5e-7) - def test_warn_deprecated_attr(self, universe): + def test_warn_deprecated_attr(self, universe, client_Contacts): """Test for warning message emitted on using deprecated `timeseries` attribute""" - CA1 = self._run_Contacts(universe, stop=1) + CA1 = self._run_Contacts( + universe, client_Contacts=client_Contacts, stop=1 + ) wmsg = "The `timeseries` attribute was deprecated in MDAnalysis" with pytest.warns(DeprecationWarning, match=wmsg): assert_equal(CA1.timeseries, CA1.results.timeseries) @@ -385,10 +405,11 @@ def test_n_initial_contacts(self, datafiles, expected): r = contacts.Contacts(u, select=select, refgroup=refgroup) assert_equal(r.n_initial_contacts, expected) -def test_q1q2(): + +def test_q1q2(client_Contacts): u = mda.Universe(PSF, DCD) q1q2 = contacts.q1q2(u, 'name CA', radius=8) - q1q2.run() + q1q2.run(**client_Contacts) q1_expected = [1., 0.98092643, 0.97366031, 0.97275204, 0.97002725, 0.97275204, 0.96276113, 0.96730245, 0.9582198, 0.96185286,