From e8fab7a1d8fb395dcd0db32c2911cee9d47bd4b7 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Thu, 3 Oct 2024 15:22:09 -0400
Subject: [PATCH] feat: add nequip (#16)

Signed-off-by: Jinzhe Zeng
---
 deepmd_mace/argcheck.py | 153 +++++++
 deepmd_mace/nequip.py | 725 +++++++++++++++++++++++++++++++
 examples/dprc/nequip/input.json | 55 +++
 examples/water/nequip/input.json | 61 +++
 pyproject.toml | 2 +
 tests/test_model.py | 46 +-
 6 files changed, 1041 insertions(+), 1 deletion(-)
 create mode 100644 deepmd_mace/nequip.py
 create mode 100644 examples/dprc/nequip/input.json
 create mode 100644 examples/water/nequip/input.json

diff --git a/deepmd_mace/argcheck.py b/deepmd_mace/argcheck.py
index 248ac0e..0f596ab 100644
--- a/deepmd_mace/argcheck.py
+++ b/deepmd_mace/argcheck.py
@@ -119,3 +119,156 @@ def mace_model_args() -> Argument:
         ],
         doc="MACE model",
     )
+
+
+@model_args_plugin.register("nequip")
+def nequip_model_args() -> Argument:
+    """Arguments for the NequIP model."""
+    doc_sel = "Maximum number of neighbor atoms."
+    doc_r_max = "distance cutoff (in Ang)"
+    doc_num_layers = "number of interaction blocks, we find 3-5 to work best"
+    doc_l_max = "the maximum irrep order (rotation order) for the network's features, l=1 is a good default, l=2 is more accurate but slower"
+    doc_num_features = "the multiplicity of the features, 32 is a good default for accurate network, if you want to be more accurate, go larger, if you want to be faster, go lower"
+    doc_nonlinearity_type = "may be 'gate' or 'norm', 'gate' is recommended"
+    doc_parity = "whether to include features with odd mirror parity; often turning parity off gives equally good results but faster networks, so do consider this"
+    doc_num_basis = (
+        "number of basis functions used in the radial basis, 8 usually works best"
+    )
+    doc_besselbasis_trainable = "set true to train the bessel weights"
+    doc_polynomialcutoff_p = "p-exponent used in polynomial cutoff function, smaller p corresponds to stronger decay with distance"
+    doc_invariant_layers = (
+        "number of radial layers, usually 1-3 works best, smaller is faster"
+    )
+    doc_invariant_neurons = (
+        "number of hidden neurons in radial function, smaller is faster"
+    )
+    doc_use_sc = "use self-connection or not, usually gives big improvement"
+    doc_irreps_edge_sh = "irreps for the chemical embedding of species"
+    doc_feature_irreps_hidden = "irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster"
+    doc_chemical_embedding_irreps_out = "irreps of the spherical harmonics used for edges. 
If a single integer, indicates the full SH up to L_max=that_integer" + doc_conv_to_output_hidden_irreps_out = "irreps used in hidden layer of output block" + return Argument( + "nequip", + dict, + [ + Argument( + "sel", + [int, str], + optional=False, + doc=doc_sel, + ), + Argument( + "r_max", + float, + optional=True, + default=6.0, + doc=doc_r_max, + ), + Argument( + "num_layers", + int, + optional=True, + default=4, + doc=doc_num_layers, + ), + Argument( + "l_max", + int, + optional=True, + default=2, + doc=doc_l_max, + ), + Argument( + "num_features", + int, + optional=True, + default=32, + doc=doc_num_features, + ), + Argument( + "nonlinearity_type", + str, + optional=True, + default="gate", + doc=doc_nonlinearity_type, + ), + Argument( + "parity", + bool, + optional=True, + default=True, + doc=doc_parity, + ), + Argument( + "num_basis", + int, + optional=True, + default=8, + doc=doc_num_basis, + ), + Argument( + "BesselBasis_trainable", + bool, + optional=True, + default=True, + doc=doc_besselbasis_trainable, + ), + Argument( + "PolynomialCutoff_p", + int, + optional=True, + default=6, + doc=doc_polynomialcutoff_p, + ), + Argument( + "invariant_layers", + int, + optional=True, + default=2, + doc=doc_invariant_layers, + ), + Argument( + "invariant_neurons", + int, + optional=True, + default=64, + doc=doc_invariant_neurons, + ), + Argument( + "use_sc", + bool, + optional=True, + default=True, + doc=doc_use_sc, + ), + Argument( + "irreps_edge_sh", + str, + optional=True, + default="0e + 1e", + doc=doc_irreps_edge_sh, + ), + Argument( + "feature_irreps_hidden", + str, + optional=True, + default="32x0o + 32x0e + 32x1o + 32x1e", + doc=doc_feature_irreps_hidden, + ), + Argument( + "chemical_embedding_irreps_out", + str, + optional=True, + default="32x0e", + doc=doc_chemical_embedding_irreps_out, + ), + Argument( + "conv_to_output_hidden_irreps_out", + str, + optional=True, + default="16x0e", + doc=doc_conv_to_output_hidden_irreps_out, + ), + ], + doc="Nequip model", + ) diff --git a/deepmd_mace/nequip.py b/deepmd_mace/nequip.py new file mode 100644 index 0000000..273cddc --- /dev/null +++ b/deepmd_mace/nequip.py @@ -0,0 +1,725 @@ +"""Nequip model.""" + +import json +from copy import deepcopy +from typing import Any, Optional + +import torch +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + build_neighbor_list, + extend_input_and_build_neighbor_list, +) +from deepmd.pt.utils.region import ( + phys2inter, +) +from deepmd.pt.utils.stat import ( + compute_output_stats, +) +from deepmd.pt.utils.update_sel import ( + UpdateSel, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) +from e3nn.util.jit import ( + script, +) +from nequip.model import model_from_config + + +@BaseModel.register("nequip") +class NequipModel(BaseModel): + """Nequip model. 
+
+    Parameters
+    ----------
+    type_map : list[str]
+        The name of each type of atoms
+    sel : int
+        Maximum number of neighbor atoms
+    r_max : float, optional
+        distance cutoff (in Ang)
+    num_layers : int
+        number of interaction blocks, we find 3-5 to work best
+    l_max : int
+        the maximum irrep order (rotation order) for the network's features, l=1 is a good default, l=2 is more accurate but slower
+    num_features : int
+        the multiplicity of the features, 32 is a good default for accurate network, if you want to be more accurate, go larger, if you want to be faster, go lower
+    nonlinearity_type : str
+        may be 'gate' or 'norm', 'gate' is recommended
+    parity : bool
+        whether to include features with odd mirror parity; often turning parity off gives equally good results but faster networks, so do consider this
+    num_basis : int
+        number of basis functions used in the radial basis, 8 usually works best
+    BesselBasis_trainable : bool
+        set true to train the bessel weights
+    PolynomialCutoff_p : int
+        p-exponent used in polynomial cutoff function, smaller p corresponds to stronger decay with distance
+    invariant_layers : int
+        number of radial layers, usually 1-3 works best, smaller is faster
+    invariant_neurons : int
+        number of hidden neurons in radial function, smaller is faster
+    use_sc : bool
+        use self-connection or not, usually gives big improvement
+    irreps_edge_sh : str
+        irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer
+    feature_irreps_hidden : str
+        irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster
+    chemical_embedding_irreps_out : str
+        irreps for the chemical embedding of species
+    conv_to_output_hidden_irreps_out : str
+        irreps used in hidden layer of output block
+    """

+
+    mm_types: list[int]
+    e0: torch.Tensor
+
+    def __init__(
+        self,
+        type_map: list[str],
+        sel: int,
+        r_max: float = 6.0,
+        num_layers: int = 4,
+        l_max: int = 2,
+        num_features: int = 32,
+        nonlinearity_type: str = "gate",
+        parity: bool = True,
+        num_basis: int = 8,
+        BesselBasis_trainable: bool = True,
+        PolynomialCutoff_p: int = 6,
+        invariant_layers: int = 2,
+        invariant_neurons: int = 64,
+        use_sc: bool = True,
+        irreps_edge_sh: str = "0e + 1e",
+        feature_irreps_hidden: str = "32x0o + 32x0e + 32x1o + 32x1e",
+        chemical_embedding_irreps_out: str = "32x0e",
+        conv_to_output_hidden_irreps_out: str = "16x0e",
+        **kwargs: Any,  # noqa: ANN401
+    ) -> None:
+        super().__init__(**kwargs)
+        self.params = {
+            "type_map": type_map,
+            "sel": sel,
+            "r_max": r_max,
+            "num_layers": num_layers,
+            "l_max": l_max,
+            "num_features": num_features,
+            "nonlinearity_type": nonlinearity_type,
+            "parity": parity,
+            "num_basis": num_basis,
+            "BesselBasis_trainable": BesselBasis_trainable,
+            "PolynomialCutoff_p": PolynomialCutoff_p,
+            "invariant_layers": invariant_layers,
+            "invariant_neurons": invariant_neurons,
+            "use_sc": use_sc,
+            "irreps_edge_sh": irreps_edge_sh,
+            "feature_irreps_hidden": feature_irreps_hidden,
+            "chemical_embedding_irreps_out": chemical_embedding_irreps_out,
+            "conv_to_output_hidden_irreps_out": conv_to_output_hidden_irreps_out,
+        }
+        self.type_map = type_map
+        self.ntypes = len(type_map)
+        self.preset_out_bias: dict[str, list] = {"energy": []}
+        self.mm_types = []
+        self.sel = sel
+        self.num_layers = num_layers
+        for ii, tt in enumerate(type_map):
+            if not tt.startswith("m") and tt not in {"HW", "OW"}:
+                
self.preset_out_bias["energy"].append(None) + else: + self.preset_out_bias["energy"].append([0]) + self.mm_types.append(ii) + + self.rcut = r_max + self.model = script( + model_from_config( + { + "model_builders": ["EnergyModel"], + "avg_num_neighbors": sel, + "chemical_symbols": type_map, + "num_types": self.ntypes, + "r_max": r_max, + "num_layers": num_layers, + "l_max": l_max, + "num_features": num_features, + "nonlinearity_type": nonlinearity_type, + "parity": parity, + "num_basis": num_basis, + "BesselBasis_trainable": BesselBasis_trainable, + "PolynomialCutoff_p": PolynomialCutoff_p, + "invariant_layers": invariant_layers, + "invariant_neurons": invariant_neurons, + "use_sc": use_sc, + "irreps_edge_sh": irreps_edge_sh, + "feature_irreps_hidden": feature_irreps_hidden, + "chemical_embedding_irreps_out": chemical_embedding_irreps_out, + "conv_to_output_hidden_irreps_out": conv_to_output_hidden_irreps_out, + }, + ), + ) + self.register_buffer( + "e0", + torch.zeros( + self.ntypes, + dtype=env.GLOBAL_PT_ENER_FLOAT_PRECISION, + device=env.DEVICE, + ), + ) + + def compute_or_load_stat( + self, + sampled_func, # noqa: ANN001 + stat_file_path: Optional[DPPath] = None, + ) -> None: + """Compute or load the statistics parameters of the model. + + For example, mean and standard deviation of descriptors or the energy bias of + the fitting net. When `sampled` is provided, all the statistics parameters will + be calculated (or re-calculated for update), and saved in the + `stat_file_path`(s). When `sampled` is not provided, it will check the existence + of `stat_file_path`(s) and load the calculated statistics parameters. + + Parameters + ---------- + sampled_func + The sampled data frames from different data systems. + stat_file_path + The path to the statistics files. + """ + bias_out, _ = compute_output_stats( + sampled_func, + self.get_ntypes(), + keys=["energy"], + stat_file_path=stat_file_path, + rcond=None, + preset_bias=self.preset_out_bias, + ) + if "energy" in bias_out: + self.e0 = ( + bias_out["energy"] + .view(self.e0.shape) + .to(self.e0.dtype) + .to(self.e0.device) + ) + + @torch.jit.export + def fitting_output_def(self) -> FittingOutputDef: + """Get the output def of developer implemented atomic models.""" + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + ], + ) + + @torch.jit.export + def get_rcut(self) -> float: + """Get the cut-off radius.""" + return self.rcut * self.num_layers + + @torch.jit.export + def get_type_map(self) -> list[str]: + """Get the type map.""" + return self.type_map + + @torch.jit.export + def get_sel(self) -> list[int]: + """Return the number of selected atoms for each type.""" + return [self.sel] + + @torch.jit.export + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this atomic model.""" + return 0 + + @torch.jit.export + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this atomic model.""" + return 0 + + @torch.jit.export + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return [] + + @torch.jit.export + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + + If False, the shape is (nframes, nloc, ndim). 
+ """ + return False + + @torch.jit.export + def mixed_types(self) -> bool: + """Return whether the model is in mixed-types mode. + + If true, the model + 1. assumes total number of atoms aligned across frames; + 2. uses a neighbor list that does not distinguish different atomic types. + If false, the model + 1. assumes total number of atoms of each atom type aligned across frames; + 2. uses a neighbor list that distinguishes different atomic types. + """ + return True + + @torch.jit.export + def has_message_passing(self) -> bool: + """Return whether the descriptor has message passing.""" + return False + + @torch.jit.export + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + """Forward pass of the model. + + Parameters + ---------- + coord : torch.Tensor + The coordinates of atoms. + atype : torch.Tensor + The atomic types of atoms. + box : torch.Tensor, optional + The box tensor. + fparam : torch.Tensor, optional + The frame parameters. + aparam : torch.Tensor, optional + The atomic parameters. + do_atomic_virial : bool, optional + Whether to compute atomic virial. + """ + nloc = atype.shape[1] + extended_coord, extended_atype, mapping, nlist = ( + extend_input_and_build_neighbor_list( + coord, + atype, + self.rcut, + self.get_sel(), + mixed_types=True, + box=box, + ) + ) + model_ret_lower = self.forward_lower_common( + nloc, + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=None, + box=box, + ) + model_ret = communicate_extended_output( + model_ret_lower, + ModelOutputDef(self.fitting_output_def()), + mapping, + do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + return model_predict + + @torch.jit.export + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ) -> dict[str, torch.Tensor]: + """Forward lower pass of the model. + + Parameters + ---------- + extended_coord : torch.Tensor + The extended coordinates of atoms. + extended_atype : torch.Tensor + The extended atomic types of atoms. + nlist : torch.Tensor + The neighbor list. + mapping : torch.Tensor, optional + The mapping tensor. + fparam : torch.Tensor, optional + The frame parameters. + aparam : torch.Tensor, optional + The atomic parameters. + do_atomic_virial : bool, optional + Whether to compute atomic virial. + comm_dict : dict[str, torch.Tensor], optional + The communication dictionary. 
+ """ + nloc = nlist.shape[1] + nf, nall = extended_atype.shape + # recalculate nlist for ghost atoms + if self.num_layers > 1 and nloc < nall: + nlist = build_neighbor_list( + extended_coord.view(nf, -1), + extended_atype, + nall, + self.rcut * self.num_layers, + self.sel, + distinguish_types=False, + ) + model_ret = self.forward_lower_common( + nloc, + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + do_atomic_virial, + comm_dict, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(-3) + return model_predict + + def forward_lower_common( + self, + nloc: int, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, # noqa: ARG002 + comm_dict: Optional[dict[str, torch.Tensor]] = None, + box: Optional[torch.Tensor] = None, + ) -> dict[str, torch.Tensor]: + """Forward lower common pass of the model. + + Parameters + ---------- + extended_coord : torch.Tensor + The extended coordinates of atoms. + extended_atype : torch.Tensor + The extended atomic types of atoms. + nlist : torch.Tensor + The neighbor list. + mapping : torch.Tensor, optional + The mapping tensor. + fparam : torch.Tensor, optional + The frame parameters. + aparam : torch.Tensor, optional + The atomic parameters. + do_atomic_virial : bool, optional + Whether to compute atomic virial. + comm_dict : dict[str, torch.Tensor], optional + The communication dictionary. + box : torch.Tensor, optional + The box tensor. 
+ """ + nf, nall = extended_atype.shape + + extended_coord = extended_coord.view(nf, nall, 3) + extended_coord_ = extended_coord + if fparam is not None: + msg = "fparam is unsupported" + raise ValueError(msg) + if aparam is not None: + msg = "aparam is unsupported" + raise ValueError(msg) + if comm_dict is not None: + msg = "comm_dict is unsupported" + raise ValueError(msg) + nlist = nlist.to(torch.int64) + extended_atype = extended_atype.to(torch.int64) + nall = extended_coord.shape[1] + + # loop on nf + energies = [] + forces = [] + virials = [] + atom_energies = [] + atomic_virials = [] + for ff in range(nf): + extended_coord_ff = extended_coord[ff] + extended_atype_ff = extended_atype[ff] + nlist_ff = nlist[ff] + edge_index = torch.ops.deepmd_mace.mace_edge_index( + nlist_ff, + extended_atype_ff, + torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"), + ) + edge_index = edge_index.T + # Nequip and MACE have different defination for edge_index + edge_index = edge_index[[1, 0]] + + # nequip can convert dtype by itself + default_dtype = torch.float64 + extended_coord_ff = extended_coord_ff.to(default_dtype) + extended_coord_ff.requires_grad_(True) # noqa: FBT003 + + input_dict = { + "pos": extended_coord_ff, + "edge_index": edge_index, + "atom_types": extended_atype_ff, + } + if box is not None and mapping is not None: + # pass box, map edge index to real + box_ff = box[ff].to(extended_coord_ff.device) + input_dict["cell"] = box_ff + input_dict["pbc"] = torch.zeros( + 3, + dtype=torch.bool, + device=box_ff.device, + ) + shifts_atoms = extended_coord_ff - extended_coord_ff[mapping[ff]] + shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]] + edge_index = mapping[ff][edge_index] + input_dict["edge_index"] = edge_index + edge_cell_shift = phys2inter(shifts, box_ff.view(3, 3)) + input_dict["edge_cell_shift"] = edge_cell_shift + + ret = self.model.forward( + input_dict, + ) + + atom_energy = ret["atomic_energy"] + if atom_energy is None: + msg = "atom_energy is None" + raise ValueError(msg) + atom_energy = atom_energy.view(1, nall).to(extended_coord_.dtype)[:, :nloc] + # adds e0 + atom_energy = atom_energy + self.e0[extended_atype_ff[:nloc]].view( + 1, + nloc, + ).to( + atom_energy.dtype, + ) + energy = torch.sum(atom_energy, dim=1).view(1, 1).to(extended_coord_.dtype) + grad_outputs: list[Optional[torch.Tensor]] = [ + torch.ones_like(energy), + ] + force = torch.autograd.grad( + outputs=[energy], + inputs=[extended_coord_ff], + grad_outputs=grad_outputs, + retain_graph=True, + create_graph=self.training, + )[0] + if force is None: + msg = "force is None" + raise ValueError(msg) + force = -force + atomic_virial = force.unsqueeze(-1).to( + extended_coord_.dtype, + ) @ extended_coord_ff.unsqueeze(-2).to( + extended_coord_.dtype, + ) + force = force.view(1, nall, 3).to(extended_coord_.dtype) + virial = ( + torch.sum(atomic_virial, dim=0).view(1, 9).to(extended_coord_.dtype) + ) + + energies.append(energy) + forces.append(force) + virials.append(virial) + atom_energies.append(atom_energy) + atomic_virials.append(atomic_virial) + energies_t = torch.cat(energies, dim=0) + forces_t = torch.cat(forces, dim=0) + virials_t = torch.cat(virials, dim=0) + atom_energies_t = torch.cat(atom_energies, dim=0) + atomic_virials_t = torch.cat(atomic_virials, dim=0) + + return { + "energy_redu": energies_t.view(nf, 1), + "energy_derv_r": forces_t.view(nf, nall, 1, 3), + "energy_derv_c_redu": virials_t.view(nf, 1, 9), + # take the first nloc atoms to match other models + "energy": 
atom_energies_t.view(nf, nloc, 1),
+            # fake atom_virial
+            "energy_derv_c": atomic_virials_t.view(nf, nall, 1, 9),
+        }
+
+    def serialize(self) -> dict:
+        """Serialize the model."""
+        return {
+            "@class": "Model",
+            "@version": 1,
+            "type": "nequip",
+            **self.params,
+            "@variables": {
+                **{
+                    kk: to_numpy_array(vv) for kk, vv in self.model.state_dict().items()
+                },
+                "e0": to_numpy_array(self.e0),
+            },
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "NequipModel":
+        """Deserialize the model."""
+        data = data.copy()
+        if not (data.pop("@class") == "Model" and data.pop("type") == "nequip"):
+            msg = "data is not a serialized NequipModel"
+            raise ValueError(msg)
+        check_version_compatibility(data.pop("@version"), 1, 1)
+        variables = {
+            kk: to_torch_tensor(vv) for kk, vv in data.pop("@variables").items()
+        }
+        model = cls(**data)
+        model.e0 = variables.pop("e0")
+        model.model.load_state_dict(variables)
+        return model
+
+    @torch.jit.export
+    def get_nnei(self) -> int:
+        """Return the total number of selected neighboring atoms in cut-off radius."""
+        return self.sel
+
+    @torch.jit.export
+    def get_nsel(self) -> int:
+        """Return the total number of selected neighboring atoms in cut-off radius."""
+        return self.sel
+
+    @classmethod
+    def update_sel(
+        cls,
+        train_data: DeepmdDataSystem,
+        type_map: Optional[list[str]],
+        local_jdata: dict,
+    ) -> tuple[dict, Optional[float]]:
+        """Update the selection and perform neighbor statistics.
+
+        Parameters
+        ----------
+        train_data : DeepmdDataSystem
+            data used to do neighbor statistics
+        type_map : list[str], optional
+            The name of each type of atoms
+        local_jdata : dict
+            The local data refer to the current class
+
+        Returns
+        -------
+        dict
+            The updated local data
+        float
+            The minimum distance between two atoms
+        """
+        local_jdata_cpy = local_jdata.copy()
+        min_nbor_dist, sel = UpdateSel().update_one_sel(
+            train_data,
+            type_map,
+            local_jdata_cpy["r_max"],
+            local_jdata_cpy["sel"],
+            mixed_type=True,
+        )
+        local_jdata_cpy["sel"] = sel[0]
+        return local_jdata_cpy, min_nbor_dist
+
+    @torch.jit.export
+    def model_output_type(self) -> list[str]:
+        """Get the output type for the model."""
+        return ["energy"]
+
+    def translated_output_def(self) -> dict[str, Any]:
+        """Get the translated output def for the model."""
+        out_def_data = self.model_output_def().get_data()
+        output_def = {
+            "atom_energy": deepcopy(out_def_data["energy"]),
+            "energy": deepcopy(out_def_data["energy_redu"]),
+        }
+        output_def["force"] = deepcopy(out_def_data["energy_derv_r"])
+        output_def["force"].squeeze(-2)
+        output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
+        output_def["virial"].squeeze(-2)
+        output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
+        output_def["atom_virial"].squeeze(-3)
+        if "mask" in out_def_data:
+            output_def["mask"] = deepcopy(out_def_data["mask"])
+        return output_def
+
+    def model_output_def(self) -> ModelOutputDef:
+        """Get the output def for the model."""
+        return ModelOutputDef(self.fitting_output_def())
+
+    @classmethod
+    def get_model(cls, model_params: dict) -> "NequipModel":
+        """Get the model by the parameters. 
+ + Parameters + ---------- + model_params : dict + The model parameters + + Returns + ------- + BaseBaseModel + The model + """ + model_params_old = model_params.copy() + model_params = model_params.copy() + model_params.pop("type", None) + precision = model_params.pop("precision", "float32") + if precision == "float32": + torch.set_default_dtype(torch.float32) + elif precision == "float64": + torch.set_default_dtype(torch.float64) + else: + msg = f"precision {precision} not supported" + raise ValueError(msg) + model = cls(**model_params) + model.model_def_script = json.dumps(model_params_old) + return model diff --git a/examples/dprc/nequip/input.json b/examples/dprc/nequip/input.json new file mode 100644 index 0000000..d9a8655 --- /dev/null +++ b/examples/dprc/nequip/input.json @@ -0,0 +1,55 @@ +{ + "_comment1": " model parameters", + "model": { + "type": "nequip", + "type_map": [ + "C", + "P", + "O", + "H", + "OW", + "HW" + ], + "r_max": 6.0, + "sel": "auto", + "l_max": 1, + "_comment2": " that's all" + }, + + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment3": "that's all" + }, + + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment4": " that's all" + }, + + "training": { + "training_data": { + "systems": [ + "../data" + ], + "batch_size": "auto", + "_comment4": "that's all" + }, + "numb_steps": 100000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + "_comment5": "that's all" + }, + + "_comment8": "that's all" +} diff --git a/examples/water/nequip/input.json b/examples/water/nequip/input.json new file mode 100644 index 0000000..5845ec8 --- /dev/null +++ b/examples/water/nequip/input.json @@ -0,0 +1,61 @@ +{ + "_comment1": " model parameters", + "model": { + "type": "nequip", + "type_map": [ + "O", + "H" + ], + "r_max": 6.0, + "sel": "auto", + "l_max": 1, + "_comment2": " that's all" + }, + + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment3": "that's all" + }, + + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment4": " that's all" + }, + + "training": { + "training_data": { + "systems": [ + "../data/data_0/", + "../data/data_1/", + "../data/data_2/" + ], + "batch_size": "auto", + "_comment5": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment6": "that's all" + }, + "numb_steps": 1000000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000, + "_comment7": "that's all" + }, + + "_comment8": "that's all" +} diff --git a/pyproject.toml b/pyproject.toml index d20ac2e..643d36b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "torch", "deepmd-kit[torch]>=3.0.0b2", "mace-torch>=0.3.5", + "nequip", "e3nn", "dargs", ] @@ -38,6 +39,7 @@ keywords = [ [project.entry-points."deepmd.pt"] mace = "deepmd_mace.mace:MaceModel" +nequip = "deepmd_mace.nequip:NequipModel" [project.urls] repository = "https://github.com/njzjz/deepmd_mace" diff --git a/tests/test_model.py b/tests/test_model.py index b4fd8e7..28ef289 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -25,6 +25,7 @@ ) from deepmd_mace.mace import MaceModel +from deepmd_mace.nequip import 
NequipModel GLOBAL_SEED = 20240822 @@ -141,6 +142,8 @@ class ModelTestCase: """Expected number of neighbors.""" expected_has_message_passing: bool """Expected whether having message passing.""" + expected_nmpnn: int + """Expected number of MPNN.""" forward_wrapper: ClassVar[Callable[[Any, bool], Any]] """Class wrapper for forward method.""" forward_wrapper_cpu_ref: Callable[[Any], Any] @@ -165,7 +168,7 @@ def test_get_type_map(self) -> None: def test_get_rcut(self) -> None: """Test get_rcut.""" for module in self.modules_to_test: - assert module.get_rcut() == self.expected_rcut * 2 + assert module.get_rcut() == self.expected_rcut * self.expected_nmpnn def test_get_dim_fparam(self) -> None: """Test get_dim_fparam.""" @@ -1070,3 +1073,44 @@ def setUpClass(cls) -> None: cls.expected_sel_type = [] cls.expected_dim_fparam = 0 cls.expected_dim_aparam = 0 + cls.expected_nmpnn = 2 + + +class TestNequipModel(unittest.TestCase, EnerModelTest, PTTestCase): # type: ignore[misc] + """Test Nequip model.""" + + @property + def modules_to_test(self) -> list[torch.nn.Module]: # type: ignore[override] + """Modules to test.""" + skip_test_jit = getattr(self, "skip_test_jit", False) + modules = PTTestCase.modules_to_test.fget(self) # type: ignore[attr-defined] + if not skip_test_jit: + # for Model, we can test script module API + modules += [ + self._script_module + if hasattr(self, "_script_module") + else self.script_module, + ] + return modules + + _script_module: torch.jit.ScriptModule + + @classmethod + def setUpClass(cls) -> None: + """Set up class.""" + EnerModelTest.setUpClass() + + cls.module = NequipModel( + type_map=cls.expected_type_map, + sel=138, + r_max=cls.expected_rcut, + num_layers=2, + ) + with torch.jit.optimized_execution(should_optimize=False): + cls._script_module = torch.jit.script(cls.module) + cls.output_def = cls.module.translated_output_def() + cls.expected_has_message_passing = False + cls.expected_sel_type = [] + cls.expected_dim_fparam = 0 + cls.expected_dim_aparam = 0 + cls.expected_nmpnn = 2
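
Beyond the JSON-driven training path, the new class can also be smoke-tested directly, in the same way tests/test_model.py constructs it. The following is a minimal sketch, assuming an installed deepmd_mace package (including its compiled edge-index op); the `sel` value and the toy water frame are illustrative assumptions, not values taken from this patch.

import torch

from deepmd_mace.nequip import NequipModel

# mirror the "float64" precision branch of NequipModel.get_model
torch.set_default_dtype(torch.float64)

model = NequipModel(
    type_map=["O", "H"],  # as in examples/water/nequip/input.json
    sel=40,               # assumed neighbor cap; "auto" is only resolved via update_sel during training
    r_max=6.0,
    num_layers=2,
    l_max=1,
)
model.eval()

# one frame with a single water-like molecule in a 10 Ang cubic box
coord = torch.tensor(
    [[[0.00, 0.00, 0.00], [0.96, 0.00, 0.00], [-0.24, 0.93, 0.00]]],
    dtype=torch.float64,
)                                   # (nframes, nloc, 3)
atype = torch.tensor([[0, 1, 1]])   # (nframes, nloc); 0 = O, 1 = H
box = 10.0 * torch.eye(3, dtype=torch.float64).view(1, 9)  # (nframes, 9)

out = model(coord, atype, box=box)
print(out["energy"].shape, out["force"].shape, out["virial"].shape)

The example inputs added under examples/water/nequip/ and examples/dprc/nequip/ are intended for the deepmd-kit PyTorch backend (e.g. `dp --pt train input.json`), where `"sel": "auto"` is resolved through `update_sel` before training starts.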