From c0914e1ebc18d3bb7bc76571bd18ca2f8c701586 Mon Sep 17 00:00:00 2001
From: Jia-Xin Zhu <53895049+ChiahsinChu@users.noreply.github.com>
Date: Wed, 18 Dec 2024 13:25:33 +0800
Subject: [PATCH] feat (tf/pt): add atomic weights to tensor loss (#4466)

Interfaces are of particular interest in many studies. However, the
configurations in the training set that represent the interface normally
also include large parts of the bulk material. As a result, the final
model tends to favor the bulk information, while the interfacial
information is learned less well. Simply raising the proportion of
interfaces in the configurations is difficult, since the electronic
structure of the interface may only be reasonable with a certain
thickness of bulk material. Therefore, I wonder whether it is possible
to define weights for atomic quantities in the loss functions. This
allows us to assign higher weights to the atomic information in the
regions of interest and probably makes the model "more focused" on those
regions.

In this PR, I add the keyword `enable_atomic_weight` to the loss
function of the tensor model. In principle, it could be generalised to
any atomic quantity, e.g., atomic forces. I would like to hear the
developers' comments/suggestions on this feature. I can add support for
other loss functions and finish the unit tests once we agree on it.
Best.

## Summary by CodeRabbit

- **New Features**
  - Introduced an optional parameter for atomic weights in loss calculations, enhancing flexibility in the `TensorLoss` class.
  - Added a suite of unit tests for the `TensorLoss` functionality, ensuring consistency between TensorFlow and PyTorch implementations.
- **Bug Fixes**
  - Updated logic for local loss calculations to ensure correct application of atomic weights based on user input.
- **Documentation**
  - Improved clarity of documentation for several function arguments, including the addition of a new argument related to atomic weights.
---
 deepmd/pt/loss/tensor.py            |  22 ++
 deepmd/tf/loss/tensor.py            |  24 +-
 deepmd/utils/argcheck.py            |  12 +-
 source/tests/pt/test_loss_tensor.py | 464 ++++++++++++++++++++++++++++
 4 files changed, 519 insertions(+), 3 deletions(-)
 create mode 100644 source/tests/pt/test_loss_tensor.py

diff --git a/deepmd/pt/loss/tensor.py b/deepmd/pt/loss/tensor.py
index 8f2f937a07..69b133de58 100644
--- a/deepmd/pt/loss/tensor.py
+++ b/deepmd/pt/loss/tensor.py
@@ -22,6 +22,7 @@ def __init__(
         pref_atomic: float = 0.0,
         pref: float = 0.0,
         inference=False,
+        enable_atomic_weight: bool = False,
         **kwargs,
     ) -> None:
         r"""Construct a loss for local and global tensors.
@@ -40,6 +41,8 @@ def __init__(
             The prefactor of the weight of global loss. It should be larger than or equal to 0.
         inference : bool
             If true, it will output all losses found in output, ignoring the pre-factors.
+        enable_atomic_weight : bool
+            If true, atomic weights will be used in the loss calculation.
         **kwargs
             Other keyword arguments.
""" @@ -50,6 +53,7 @@ def __init__( self.local_weight = pref_atomic self.global_weight = pref self.inference = inference + self.enable_atomic_weight = enable_atomic_weight assert ( self.local_weight >= 0.0 and self.global_weight >= 0.0 @@ -85,6 +89,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False """ model_pred = model(**input_dict) del learning_rate, mae + + if self.enable_atomic_weight: + atomic_weight = label["atom_weight"].reshape([-1, 1]) + else: + atomic_weight = 1.0 + loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] more_loss = {} if ( @@ -103,6 +113,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False diff = (local_tensor_pred - local_tensor_label).reshape( [-1, self.tensor_size] ) + diff = diff * atomic_weight if "mask" in model_pred: diff = diff[model_pred["mask"].reshape([-1]).bool()] l2_local_loss = torch.mean(torch.square(diff)) @@ -171,4 +182,15 @@ def label_requirement(self) -> list[DataRequirementItem]: high_prec=False, ) ) + if self.enable_atomic_weight: + label_requirement.append( + DataRequirementItem( + "atomic_weight", + ndof=1, + atomic=True, + must=False, + high_prec=False, + default=1.0, + ) + ) return label_requirement diff --git a/deepmd/tf/loss/tensor.py b/deepmd/tf/loss/tensor.py index aca9182ff6..d7f879b4b4 100644 --- a/deepmd/tf/loss/tensor.py +++ b/deepmd/tf/loss/tensor.py @@ -40,6 +40,7 @@ def __init__(self, jdata, **kwarg) -> None: # YWolfeee: modify, use pref / pref_atomic, instead of pref_weight / pref_atomic_weight self.local_weight = jdata.get("pref_atomic", None) self.global_weight = jdata.get("pref", None) + self.enable_atomic_weight = jdata.get("enable_atomic_weight", False) assert ( self.local_weight is not None and self.global_weight is not None @@ -66,9 +67,18 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix): "global_loss": global_cvt_2_tf_float(0.0), } + if self.enable_atomic_weight: + atomic_weight = tf.reshape(label_dict["atom_weight"], [-1, 1]) + else: + atomic_weight = global_cvt_2_tf_float(1.0) + if self.local_weight > 0.0: + diff = tf.reshape(polar, [-1, self.tensor_size]) - tf.reshape( + atomic_polar_hat, [-1, self.tensor_size] + ) + diff = diff * atomic_weight local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean( - tf.square(self.scale * (polar - atomic_polar_hat)), name="l2_" + suffix + tf.square(self.scale * diff), name="l2_" + suffix ) more_loss["local_loss"] = self.display_if_exist(local_loss, find_atomic) l2_loss += self.local_weight * local_loss @@ -163,4 +173,16 @@ def label_requirement(self) -> list[DataRequirementItem]: type_sel=self.type_sel, ) ) + if self.enable_atomic_weight: + data_requirements.append( + DataRequirementItem( + "atom_weight", + 1, + atomic=True, + must=False, + high_prec=False, + default=1.0, + type_sel=self.type_sel, + ) + ) return data_requirements diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 5b57f15979..9eac0e804d 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2511,8 +2511,9 @@ def loss_property(): def loss_tensor(): # doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If only `pref` is provided or both are not provided, training will be global mode, i.e. the shape of 'polarizability.npy` or `dipole.npy` should be #frams x [9 or 3]." # doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. 
If only `pref_atomic` is provided, training will be atomic mode, i.e. the shape of `polarizability.npy` or `dipole.npy` should be #frames x ([9 or 3] x #selected atoms). If both `pref` and `pref_atomic` are provided, training will be combined mode, and atomic label should be provided as well." - doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included." - doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #selected atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0." + doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included." + doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. It controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0." + doc_enable_atomic_weight = "If true, the atomic loss will be reweighted." 
     return [
         Argument(
             "pref", [float, int], optional=False, default=None, doc=doc_global_weight
         ),
         Argument(
             "pref_atomic",
             [float, int],
             optional=False,
             default=None,
             doc=doc_local_weight,
         ),
+        Argument(
+            "enable_atomic_weight",
+            bool,
+            optional=True,
+            default=False,
+            doc=doc_enable_atomic_weight,
+        ),
     ]

diff --git a/source/tests/pt/test_loss_tensor.py b/source/tests/pt/test_loss_tensor.py
new file mode 100644
index 0000000000..5802c0b775
--- /dev/null
+++ b/source/tests/pt/test_loss_tensor.py
@@ -0,0 +1,464 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import os
+import unittest
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+import torch
+
+tf.disable_eager_execution()
+from pathlib import (
+    Path,
+)
+
+from deepmd.pt.loss import TensorLoss as PTTensorLoss
+from deepmd.pt.utils import (
+    dp_random,
+    env,
+)
+from deepmd.pt.utils.dataset import (
+    DeepmdDataSetForLoader,
+)
+from deepmd.tf.loss.tensor import TensorLoss as TFTensorLoss
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
+
+from ..seed import (
+    GLOBAL_SEED,
+)
+
+CUR_DIR = os.path.dirname(__file__)
+
+
+def get_batch(system, type_map, data_requirement):
+    dataset = DeepmdDataSetForLoader(system, type_map)
+    dataset.add_data_requirement(data_requirement)
+    np_batch, pt_batch = get_single_batch(dataset)
+    return np_batch, pt_batch
+
+
+def get_single_batch(dataset, index=None):
+    if index is None:
+        index = dp_random.choice(np.arange(len(dataset)))
+    np_batch = dataset[index]
+    pt_batch = {}
+
+    for key in [
+        "coord",
+        "box",
+        "atom_dipole",
+        "dipole",
+        "atom_polarizability",
+        "polarizability",
+        "atype",
+        "natoms",
+    ]:
+        if key in np_batch.keys():
+            np_batch[key] = np.expand_dims(np_batch[key], axis=0)
+            pt_batch[key] = torch.as_tensor(np_batch[key], device=env.DEVICE)
+            if key in ["coord", "atom_dipole"]:
+                np_batch[key] = np_batch[key].reshape(1, -1)
+    np_batch["natoms"] = np_batch["natoms"][0]
+    return np_batch, pt_batch
+
+
+class LossCommonTest(unittest.TestCase):
+    def setUp(self) -> None:
+        self.cur_lr = 1.2
+        self.type_map = ["H", "O"]
+
+        # data
+        tensor_data_requirement = [
+            DataRequirementItem(
+                "atomic_" + self.label_name,
+                ndof=self.tensor_size,
+                atomic=True,
+                must=False,
+                high_prec=False,
+            ),
+            DataRequirementItem(
+                self.label_name,
+                ndof=self.tensor_size,
+                atomic=False,
+                must=False,
+                high_prec=False,
+            ),
+            DataRequirementItem(
+                "atomic_weight",
+                ndof=1,
+                atomic=True,
+                must=False,
+                high_prec=False,
+                default=1.0,
+            ),
+        ]
+        np_batch, pt_batch = get_batch(
+            self.system, self.type_map, tensor_data_requirement
+        )
+        natoms = np_batch["natoms"]
+        self.nloc = natoms[0]
+        self.nframes = np_batch["atom_" + self.label_name].shape[0]
+        rng = np.random.default_rng(GLOBAL_SEED)
+
+        l_atomic_tensor, l_global_tensor = (
+            np_batch["atom_" + self.label_name],
+            np_batch[self.label_name],
+        )
+        p_atomic_tensor, p_global_tensor = (
+            np.ones_like(l_atomic_tensor),
+            np.ones_like(l_global_tensor),
+        )
+
+        batch_size = pt_batch["coord"].shape[0]
+
+        # atom_pref = rng.random(size=[batch_size, nloc * 3])
+        # drdq = rng.random(size=[batch_size, nloc * 2 * 3])
+        atom_weight = rng.random(size=[batch_size, self.nloc])
+
+        # tf
+        self.g = tf.Graph()
+        with self.g.as_default():
+            t_cur_lr = tf.placeholder(shape=[], dtype=tf.float64)
+            t_natoms = tf.placeholder(shape=[None], dtype=tf.int32)
+            t_patomic_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64)
+            t_pglobal_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64)
+            t_latomic_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64)
+            t_lglobal_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64)
+            t_atom_weight = tf.placeholder(shape=[None, None], dtype=tf.float64)
+            find_atomic = tf.constant(1.0, dtype=tf.float64)
+            find_global = tf.constant(1.0, dtype=tf.float64)
+            find_atom_weight = tf.constant(1.0, dtype=tf.float64)
+            model_dict = {
+                self.tensor_name: t_patomic_tensor,
+            }
+            label_dict = {
+                "atom_" + self.label_name: t_latomic_tensor,
+                "find_atom_" + self.label_name: find_atomic,
+                self.label_name: t_lglobal_tensor,
+                "find_" + self.label_name: find_global,
+                "atom_weight": t_atom_weight,
+                "find_atom_weight": find_atom_weight,
+            }
+            self.tf_loss_sess = self.tf_loss.build(
+                t_cur_lr, t_natoms, model_dict, label_dict, ""
+            )
+
+        self.feed_dict = {
+            t_cur_lr: self.cur_lr,
+            t_natoms: natoms,
+            t_patomic_tensor: p_atomic_tensor,
+            t_pglobal_tensor: p_global_tensor,
+            t_latomic_tensor: l_atomic_tensor,
+            t_lglobal_tensor: l_global_tensor,
+            t_atom_weight: atom_weight,
+        }
+        # pt
+        self.model_pred = {
+            self.tensor_name: torch.from_numpy(p_atomic_tensor),
+            "global_" + self.tensor_name: torch.from_numpy(p_global_tensor),
+        }
+        self.label = {
+            "atom_" + self.label_name: torch.from_numpy(l_atomic_tensor),
+            "find_" + "atom_" + self.label_name: 1.0,
+            self.label_name: torch.from_numpy(l_global_tensor),
+            "find_" + self.label_name: 1.0,
+            "atom_weight": torch.from_numpy(atom_weight),
+            "find_atom_weight": 1.0,
+        }
+        self.label_absent = {
+            "atom_" + self.label_name: torch.from_numpy(l_atomic_tensor),
+            self.label_name: torch.from_numpy(l_global_tensor),
+            "atom_weight": torch.from_numpy(atom_weight),
+        }
+        self.natoms = pt_batch["natoms"]
+
+    def tearDown(self) -> None:
+        tf.reset_default_graph()
+        return super().tearDown()
+
+
+class TestAtomicDipoleLoss(LossCommonTest):
+    def setUp(self) -> None:
+        self.tensor_name = "dipole"
+        self.tensor_size = 3
+        self.label_name = "dipole"
+        self.system = str(Path(__file__).parent / "water_tensor/dipole/O78H156")
+
+        self.pref_atomic = 1.0
+        self.pref = 0.0
+        # tf
+        self.tf_loss = TFTensorLoss(
+            {
+                "pref_atomic": self.pref_atomic,
+                "pref": self.pref,
+            },
+            tensor_name=self.tensor_name,
+            tensor_size=self.tensor_size,
+            label_name=self.label_name,
+        )
+        # pt
+        self.pt_loss = PTTensorLoss(
+            self.tensor_name,
+            self.tensor_size,
+            self.label_name,
+            self.pref_atomic,
+            self.pref,
+        )
+
+        super().setUp()
+
+    def test_consistency(self) -> None:
+        with tf.Session(graph=self.g) as sess:
+            tf_loss, tf_more_loss = sess.run(
+                self.tf_loss_sess, feed_dict=self.feed_dict
+            )
+
+        def fake_model():
+            return self.model_pred
+
+        _, pt_loss, pt_more_loss = self.pt_loss(
+            {},
+            fake_model,
+            self.label,
+            self.nloc,
+            self.cur_lr,
+        )
+        _, pt_loss_absent, pt_more_loss_absent = self.pt_loss(
+            {},
+            fake_model,
+            self.label_absent,
+            self.nloc,
+            self.cur_lr,
+        )
+        pt_loss = pt_loss.detach().cpu()
+        pt_loss_absent = pt_loss_absent.detach().cpu()
+        self.assertTrue(np.allclose(tf_loss, pt_loss.numpy()))
+        self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy()))
+        for key in ["local"]:
+            self.assertTrue(
+                np.allclose(
+                    tf_more_loss[f"{key}_loss"],
+                    pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"],
+                )
+            )
+            self.assertTrue(
+                np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"])
+            )
+
+
+class TestAtomicDipoleAWeightLoss(LossCommonTest):
+    def setUp(self) -> None:
+        self.tensor_name = "dipole"
+        self.tensor_size = 3
+        self.label_name = "dipole"
+        self.system = str(Path(__file__).parent / "water_tensor/dipole/O78H156")
+
+        self.pref_atomic = 1.0
+        self.pref = 0.0
+        # tf
+        self.tf_loss = TFTensorLoss(
+            {
+                "pref_atomic": self.pref_atomic,
+                "pref": self.pref,
+                "enable_atomic_weight": True,
+            },
+            tensor_name=self.tensor_name,
+            tensor_size=self.tensor_size,
+            label_name=self.label_name,
+        )
+        # pt
+        self.pt_loss = PTTensorLoss(
+            self.tensor_name,
+            self.tensor_size,
+            self.label_name,
+            self.pref_atomic,
+            self.pref,
+            enable_atomic_weight=True,
+        )
+
+        super().setUp()
+
+    def test_consistency(self) -> None:
+        with tf.Session(graph=self.g) as sess:
+            tf_loss, tf_more_loss = sess.run(
+                self.tf_loss_sess, feed_dict=self.feed_dict
+            )
+
+        def fake_model():
+            return self.model_pred
+
+        _, pt_loss, pt_more_loss = self.pt_loss(
+            {},
+            fake_model,
+            self.label,
+            self.nloc,
+            self.cur_lr,
+        )
+        _, pt_loss_absent, pt_more_loss_absent = self.pt_loss(
+            {},
+            fake_model,
+            self.label_absent,
+            self.nloc,
+            self.cur_lr,
+        )
+        pt_loss = pt_loss.detach().cpu()
+        pt_loss_absent = pt_loss_absent.detach().cpu()
+        self.assertTrue(np.allclose(tf_loss, pt_loss.numpy()))
+        self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy()))
+        for key in ["local"]:
+            self.assertTrue(
+                np.allclose(
+                    tf_more_loss[f"{key}_loss"],
+                    pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"],
+                )
+            )
+            self.assertTrue(
+                np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"])
+            )
+
+
+class TestAtomicPolarLoss(LossCommonTest):
+    def setUp(self) -> None:
+        self.tensor_name = "polar"
+        self.tensor_size = 9
+        self.label_name = "polarizability"
+
+        self.system = str(Path(__file__).parent / "water_tensor/polar/atomic_system")
+
+        self.pref_atomic = 1.0
+        self.pref = 0.0
+        # tf
+        self.tf_loss = TFTensorLoss(
+            {
+                "pref_atomic": self.pref_atomic,
+                "pref": self.pref,
+            },
+            tensor_name=self.tensor_name,
+            tensor_size=self.tensor_size,
+            label_name=self.label_name,
+        )
+        # pt
+        self.pt_loss = PTTensorLoss(
+            self.tensor_name,
+            self.tensor_size,
+            self.label_name,
+            self.pref_atomic,
+            self.pref,
+        )
+
+        super().setUp()
+
+    def test_consistency(self) -> None:
+        with tf.Session(graph=self.g) as sess:
+            tf_loss, tf_more_loss = sess.run(
+                self.tf_loss_sess, feed_dict=self.feed_dict
+            )
+
+        def fake_model():
+            return self.model_pred
+
+        _, pt_loss, pt_more_loss = self.pt_loss(
+            {},
+            fake_model,
+            self.label,
+            self.nloc,
+            self.cur_lr,
+        )
+        _, pt_loss_absent, pt_more_loss_absent = self.pt_loss(
+            {},
+            fake_model,
+            self.label_absent,
+            self.nloc,
+            self.cur_lr,
+        )
+        pt_loss = pt_loss.detach().cpu()
+        pt_loss_absent = pt_loss_absent.detach().cpu()
+        self.assertTrue(np.allclose(tf_loss, pt_loss.numpy()))
+        self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy()))
+        for key in ["local"]:
+            self.assertTrue(
+                np.allclose(
+                    tf_more_loss[f"{key}_loss"],
+                    pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"],
+                )
+            )
+            self.assertTrue(
+                np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"])
+            )
+
+
+class TestAtomicPolarAWeightLoss(LossCommonTest):
+    def setUp(self) -> None:
+        self.tensor_name = "polar"
+        self.tensor_size = 9
+        self.label_name = "polarizability"
+
+        self.system = str(Path(__file__).parent / "water_tensor/polar/atomic_system")
+
+        self.pref_atomic = 1.0
+        self.pref = 0.0
+        # tf
+        self.tf_loss = TFTensorLoss(
+            {
+                "pref_atomic": self.pref_atomic,
+                "pref": self.pref,
+                "enable_atomic_weight": True,
+            },
+            tensor_name=self.tensor_name,
+            tensor_size=self.tensor_size,
+            label_name=self.label_name,
+        )
+        # pt
+        self.pt_loss = PTTensorLoss(
+            self.tensor_name,
+            self.tensor_size,
+            self.label_name,
+            self.pref_atomic,
+            self.pref,
+            enable_atomic_weight=True,
+        )
+
+        super().setUp()
+
+    def test_consistency(self) -> None:
+        with tf.Session(graph=self.g) as sess:
+            tf_loss, tf_more_loss = sess.run(
+                self.tf_loss_sess, feed_dict=self.feed_dict
+            )
+
+        def fake_model():
+            return self.model_pred
+
+        _, pt_loss, pt_more_loss = self.pt_loss(
+            {},
+            fake_model,
+            self.label,
+            self.nloc,
+            self.cur_lr,
+        )
+        _, pt_loss_absent, pt_more_loss_absent = self.pt_loss(
+            {},
+            fake_model,
+            self.label_absent,
+            self.nloc,
+            self.cur_lr,
+        )
+        pt_loss = pt_loss.detach().cpu()
+        pt_loss_absent = pt_loss_absent.detach().cpu()
+        self.assertTrue(np.allclose(tf_loss, pt_loss.numpy()))
+        self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy()))
+        for key in ["local"]:
+            self.assertTrue(
+                np.allclose(
+                    tf_more_loss[f"{key}_loss"],
+                    pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"],
+                )
+            )
+            self.assertTrue(
+                np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"])
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
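
---

A minimal usage sketch, mirroring the new unit tests: it builds the PT loss with the flag switched on and evaluates it on dummy predictions and labels, on CPU as in the tests. The system size, the all-ones labels, and the zero predictions are placeholders for illustration only; the dipole case (`tensor_size = 3`) is assumed.

```python
import numpy as np
import torch

from deepmd.pt.loss import TensorLoss

nframes, natoms = 1, 4

# Build the loss in atomic mode (pref=0.0) with the new flag switched on.
loss_fn = TensorLoss(
    "dipole",  # tensor_name: key of the prediction in the model output
    3,         # tensor_size: 3 for dipole, 9 for polarizability
    "dipole",  # label_name: key of the label in the data
    pref_atomic=1.0,
    pref=0.0,
    enable_atomic_weight=True,
)

# Per-atom weights: up-weight the first atom, give the last one zero weight
# so it no longer contributes to the local loss.
weights = np.ones((nframes, natoms))
weights[:, 0] = 10.0
weights[:, -1] = 0.0

label = {
    "atom_dipole": torch.ones(nframes, natoms * 3, dtype=torch.float64),
    "find_atom_dipole": 1.0,
    "atom_weight": torch.from_numpy(weights),
    "find_atom_weight": 1.0,
}


def fake_model():
    # Stands in for a real model, as in the unit tests above.
    return {"dipole": torch.zeros(nframes, natoms * 3, dtype=torch.float64)}


_, loss, more_loss = loss_fn({}, fake_model, label, natoms, 1.0)
print(loss, more_loss["l2_local_dipole_loss"])
```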
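On the data side, the weights are declared through a `DataRequirementItem` like any other atomic label, so they should be shipped as a per-set `.npy` array of shape `(nframes, natoms)`. Below is a sketch under that assumption; note that the PT branch registers the item as `atomic_weight` while the TF branch registers `atom_weight`, so the file name (PT convention here) and the path are assumptions to double-check.

```python
import numpy as np

# Hypothetical system: 20 frames, 192 atoms, atoms 0-47 form the interface.
nframes, natoms = 20, 192

weights = np.ones((nframes, natoms))
weights[:, :48] = 5.0  # emphasize the interfacial region in the local loss

# File name follows the PT DataRequirementItem ("atomic_weight");
# the TF loss registers "atom_weight" instead -- confirm before shipping.
np.save("my_system/set.000/atomic_weight.npy", weights)
```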