Skip to content

Commit

Permalink
[fmt] transformations (#4809)
Browse files Browse the repository at this point in the history
  • Loading branch information
RMeli authored Nov 30, 2024
1 parent 905f197 commit 557f27d
Show file tree
Hide file tree
Showing 18 changed files with 775 additions and 542 deletions.
4 changes: 2 additions & 2 deletions package/MDAnalysis/transformations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def __init__(self, **kwargs):
analysis approach.
Default is ``True``.
"""
self.max_threads = kwargs.pop('max_threads', None)
self.parallelizable = kwargs.pop('parallelizable', True)
self.max_threads = kwargs.pop("max_threads", None)
self.parallelizable = kwargs.pop("parallelizable", True)

def __call__(self, ts):
"""The function that makes transformation can be called as a function
Expand Down
29 changes: 14 additions & 15 deletions package/MDAnalysis/transformations/boxdimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from .base import TransformationBase


class set_dimensions(TransformationBase):
"""
Set simulation box dimensions.
Expand Down Expand Up @@ -85,33 +86,31 @@ class set_dimensions(TransformationBase):
Added the option to set varying box dimensions (i.e. an NPT trajectory).
"""

def __init__(self,
dimensions,
max_threads=None,
parallelizable=True):
super().__init__(max_threads=max_threads,
parallelizable=parallelizable)
def __init__(self, dimensions, max_threads=None, parallelizable=True):
super().__init__(
max_threads=max_threads, parallelizable=parallelizable
)
self.dimensions = dimensions

try:
self.dimensions = np.asarray(self.dimensions, np.float32)
except ValueError:
errmsg = (
f'{self.dimensions} cannot be converted into '
'np.float32 numpy.ndarray'
f"{self.dimensions} cannot be converted into "
"np.float32 numpy.ndarray"
)
raise ValueError(errmsg)
try:
self.dimensions = self.dimensions.reshape(-1, 6)
except ValueError:
errmsg = (
f'{self.dimensions} array does not have valid box '
'dimension shape.\nSimulation box dimensions are '
'given by an float array of shape (6, 0), (1, 6), '
'or (N, 6) where N is the number of frames in the '
'trajectory and the dimension vector(s) containing '
'3 lengths and 3 angles: '
'[a, b, c, alpha, beta, gamma]'
f"{self.dimensions} array does not have valid box "
"dimension shape.\nSimulation box dimensions are "
"given by an float array of shape (6, 0), (1, 6), "
"or (N, 6) where N is the number of frames in the "
"trajectory and the dimension vector(s) containing "
"3 lengths and 3 angles: "
"[a, b, c, alpha, beta, gamma]"
)
raise ValueError(errmsg)

Expand Down
98 changes: 61 additions & 37 deletions package/MDAnalysis/transformations/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,45 +90,57 @@ class fit_translation(TransformationBase):
The transformation was changed to inherit from the base class for
limiting threads and checking if it can be used in parallel analysis.
"""
def __init__(self, ag, reference, plane=None, weights=None,
max_threads=None, parallelizable=True):
super().__init__(max_threads=max_threads,
parallelizable=parallelizable)

def __init__(
self,
ag,
reference,
plane=None,
weights=None,
max_threads=None,
parallelizable=True,
):
super().__init__(
max_threads=max_threads, parallelizable=parallelizable
)

self.ag = ag
self.reference = reference
self.plane = plane
self.weights = weights

if self.plane is not None:
axes = {'yz': 0, 'xz': 1, 'xy': 2}
axes = {"yz": 0, "xz": 1, "xy": 2}
try:
self.plane = axes[self.plane]
except (TypeError, KeyError):
raise ValueError(f'{self.plane} is not a valid plane') \
from None
raise ValueError(
f"{self.plane} is not a valid plane"
) from None
try:
if self.ag.atoms.n_residues != self.reference.atoms.n_residues:
errmsg = (
f"{self.ag} and {self.reference} have mismatched"
f"number of residues"
f"{self.ag} and {self.reference} have mismatched"
f"number of residues"
)

raise ValueError(errmsg)
except AttributeError:
errmsg = (
f"{self.ag} or {self.reference} is not valid"
f"Universe/AtomGroup"
f"{self.ag} or {self.reference} is not valid"
f"Universe/AtomGroup"
)
raise AttributeError(errmsg) from None
self.ref, self.mobile = align.get_matching_atoms(self.reference.atoms,
self.ag.atoms)
self.ref, self.mobile = align.get_matching_atoms(
self.reference.atoms, self.ag.atoms
)
self.weights = align.get_weights(self.ref.atoms, weights=self.weights)
self.ref_com = self.ref.center(self.weights)

def _transform(self, ts):
mobile_com = np.asarray(self.mobile.atoms.center(self.weights),
np.float32)
mobile_com = np.asarray(
self.mobile.atoms.center(self.weights), np.float32
)
vector = self.ref_com - mobile_com
if self.plane is not None:
vector[self.plane] = 0
Expand Down Expand Up @@ -197,23 +209,33 @@ class fit_rot_trans(TransformationBase):
The transformation was changed to inherit from the base class for
limiting threads and checking if it can be used in parallel analysis.
"""
def __init__(self, ag, reference, plane=None, weights=None,
max_threads=1, parallelizable=True):
super().__init__(max_threads=max_threads,
parallelizable=parallelizable)

def __init__(
self,
ag,
reference,
plane=None,
weights=None,
max_threads=1,
parallelizable=True,
):
super().__init__(
max_threads=max_threads, parallelizable=parallelizable
)

self.ag = ag
self.reference = reference
self.plane = plane
self.weights = weights

if self.plane is not None:
axes = {'yz': 0, 'xz': 1, 'xy': 2}
axes = {"yz": 0, "xz": 1, "xy": 2}
try:
self.plane = axes[self.plane]
except (TypeError, KeyError):
raise ValueError(f'{self.plane} is not a valid plane') \
from None
raise ValueError(
f"{self.plane} is not a valid plane"
) from None
try:
if self.ag.atoms.n_residues != self.reference.atoms.n_residues:
errmsg = (
Expand All @@ -223,35 +245,37 @@ def __init__(self, ag, reference, plane=None, weights=None,
raise ValueError(errmsg)
except AttributeError:
errmsg = (
f"{self.ag} or {self.reference} is not valid "
f"Universe/AtomGroup"
f"{self.ag} or {self.reference} is not valid "
f"Universe/AtomGroup"
)
raise AttributeError(errmsg) from None
self.ref, self.mobile = align.get_matching_atoms(self.reference.atoms,
self.ag.atoms)
self.ref, self.mobile = align.get_matching_atoms(
self.reference.atoms, self.ag.atoms
)
self.weights = align.get_weights(self.ref.atoms, weights=self.weights)
self.ref_com = self.ref.center(self.weights)
self.ref_coordinates = self.ref.atoms.positions - self.ref_com

def _transform(self, ts):
mobile_com = self.mobile.atoms.center(self.weights)
mobile_coordinates = self.mobile.atoms.positions - mobile_com
rotation, dump = align.rotation_matrix(mobile_coordinates,
self.ref_coordinates,
weights=self.weights)
rotation, dump = align.rotation_matrix(
mobile_coordinates, self.ref_coordinates, weights=self.weights
)
vector = self.ref_com
if self.plane is not None:
matrix = np.r_[rotation, np.zeros(3).reshape(1, 3)]
matrix = np.c_[matrix, np.zeros(4)]
euler_angs = np.asarray(euler_from_matrix(matrix, axes='sxyz'),
np.float32)
euler_angs = np.asarray(
euler_from_matrix(matrix, axes="sxyz"), np.float32
)
for i in range(0, euler_angs.size):
euler_angs[i] = (euler_angs[self.plane] if i == self.plane
else 0)
rotation = euler_matrix(euler_angs[0],
euler_angs[1],
euler_angs[2],
axes='sxyz')[:3, :3]
euler_angs[i] = (
euler_angs[self.plane] if i == self.plane else 0
)
rotation = euler_matrix(
euler_angs[0], euler_angs[1], euler_angs[2], axes="sxyz"
)[:3, :3]
vector[self.plane] = mobile_com[self.plane]
ts.positions = ts.positions - mobile_com
ts.positions = np.dot(ts.positions, rotation.T)
Expand Down
13 changes: 9 additions & 4 deletions package/MDAnalysis/transformations/nojump.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class NoJump(TransformationBase):
across periodic boundary edges. The algorithm used is based on :footcite:p:`Kulke2022`,
equation B6 for non-orthogonal systems, so it is general to most applications where
molecule trajectories should not "jump" from one side of a periodic box to another.
Note that this transformation depends on a periodic box dimension being set for every
frame in the trajectory, and that this box dimension can be transformed to an orthonormal
unit cell. If not, an error is emitted. Since it is typical to transform all frames
Expand Down Expand Up @@ -133,7 +133,8 @@ def _transform(self, ts):
if (
self.check_c
and self.older_frame != "A"
and (self.old_frame - self.older_frame) != (ts.frame - self.old_frame)
and (self.old_frame - self.older_frame)
!= (ts.frame - self.old_frame)
):
warnings.warn(
"NoJump detected that the interval between frames is unequal."
Expand All @@ -155,7 +156,9 @@ def _transform(self, ts):
)
# Convert into reduced coordinate space
fcurrent = ts.positions @ Linverse
fprev = self.prev # Previous unwrapped coordinates in reduced box coordinates.
fprev = (
self.prev
) # Previous unwrapped coordinates in reduced box coordinates.
# Calculate the new positions in reduced coordinate space (Equation B6 from
# 10.1021/acs.jctc.2c00327). As it turns out, the displacement term can
# be moved inside the round function in this coordinate space, as the
Expand All @@ -164,7 +167,9 @@ def _transform(self, ts):
# Convert back into real space
ts.positions = newpositions @ L
# Set things we need to save for the next frame.
self.prev = newpositions # Note that this is in reduced coordinate space.
self.prev = (
newpositions # Note that this is in reduced coordinate space.
)
self.older_frame = self.old_frame
self.old_frame = ts.frame

Expand Down
Loading

0 comments on commit 557f27d

Please sign in to comment.