Skip to content

Commit

Permalink
Merge pull request #1641 from amcadmus/master
Browse files Browse the repository at this point in the history
Merge devel into master
  • Loading branch information
amcadmus authored Apr 16, 2022
2 parents 3e54fea + 4286937 commit c4f0cec
Show file tree
Hide file tree
Showing 55 changed files with 1,526 additions and 221 deletions.
9 changes: 6 additions & 3 deletions deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,16 +351,19 @@ def get_feed_dict(self,
return feed_dict

def init_variables(self,
model_file: str,
graph: tf.Graph,
graph_def: tf.GraphDef,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict
Parameters
----------
model_file : str
The input model file
graph : tf.Graph
The input frozen model graph
graph_def : tf.GraphDef
The input frozen model graph_def
suffix : str, optional
The suffix of the scope
Expand Down
11 changes: 7 additions & 4 deletions deepmd/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,21 +279,24 @@ def enable_mixed_precision(self, mixed_prec : dict = None) -> None:


def init_variables(self,
model_file : str,
graph: tf.Graph,
graph_def: tf.GraphDef,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict
Parameters
----------
model_file : str
The input frozen model file
graph : tf.Graph
The input frozen model graph
graph_def : tf.GraphDef
The input frozen model graph_def
suffix : str, optional
The suffix of the scope
"""
for idx, ii in enumerate(self.descrpt_list):
ii.init_variables(model_file, suffix=f"{suffix}_{idx}")
ii.init_variables(graph, graph_def, suffix=f"{suffix}_{idx}")

def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
"""Get names of tensors.
Expand Down
17 changes: 10 additions & 7 deletions deepmd/descriptor/loc_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from deepmd.env import default_tf_session_config
from deepmd.utils.sess import run_sess
from .descriptor import Descriptor
from deepmd.utils.graph import get_tensor_by_name
from deepmd.utils.graph import get_tensor_by_name_from_graph

@Descriptor.register("loc_frame")
class DescrptLocFrame (Descriptor) :
Expand Down Expand Up @@ -369,18 +369,21 @@ def _compute_std (self,sumv2, sumv, sumn) :
return np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn))

def init_variables(self,
model_file : str,
graph: tf.Graph,
graph_def: tf.GraphDef,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given frozen model
Init the embedding net variables with the given dict
Parameters
----------
model_file : str
The input frozen model file
graph : tf.Graph
The input frozen model graph
graph_def : tf.GraphDef
The input frozen model graph_def
suffix : str, optional
The suffix of the scope
"""
self.davg = get_tensor_by_name(model_file, 'descrpt_attr%s/t_avg' % suffix)
self.tavg = get_tensor_by_name(model_file, 'descrpt_attr%s/t_std' % suffix)
self.davg = get_tensor_by_name_from_graph(graph, 'descrpt_attr%s/t_avg' % suffix)
self.dstd = get_tensor_by_name_from_graph(graph, 'descrpt_attr%s/t_std' % suffix)
19 changes: 11 additions & 8 deletions deepmd/descriptor/se.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Tuple, List

from deepmd.env import tf
from deepmd.utils.graph import get_embedding_net_variables, get_tensor_by_name
from deepmd.utils.graph import get_embedding_net_variables_from_graph_def, get_tensor_by_name_from_graph
from .descriptor import Descriptor


Expand Down Expand Up @@ -92,22 +92,25 @@ def pass_tensors_from_frz_model(self,
self.descrpt_reshape = descrpt_reshape

def init_variables(self,
model_file : str,
graph: tf.Graph,
graph_def: tf.GraphDef,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given frozen model
Init the embedding net variables with the given dict
Parameters
----------
model_file : str
The input frozen model file
graph : tf.Graph
The input frozen model graph
graph_def : tf.GraphDef
The input frozen model graph_def
suffix : str, optional
The suffix of the scope
"""
self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix)
self.davg = get_tensor_by_name(model_file, 'descrpt_attr%s/t_avg' % suffix)
self.tavg = get_tensor_by_name(model_file, 'descrpt_attr%s/t_std' % suffix)
self.embedding_net_variables = get_embedding_net_variables_from_graph_def(graph_def, suffix = suffix)
self.davg = get_tensor_by_name_from_graph(graph, 'descrpt_attr%s/t_avg' % suffix)
self.dstd = get_tensor_by_name_from_graph(graph, 'descrpt_attr%s/t_std' % suffix)

@property
def precision(self) -> tf.DType:
Expand Down
86 changes: 79 additions & 7 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from deepmd.utils.tabulate import DPTabulate
from deepmd.utils.type_embed import embed_atom_type
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_tensor_by_name
from deepmd.utils.errors import GraphWithoutTensorError
from .descriptor import Descriptor
from .se import DescrptSe

Expand Down Expand Up @@ -142,8 +143,6 @@ def __init__ (self,
self.exclude_types.add((tt[1], tt[0]))
self.set_davg_zero = set_davg_zero
self.type_one_side = type_one_side
if self.type_one_side and len(exclude_types) != 0:
raise RuntimeError('"type_one_side" is not compatible with "exclude_types"')

# descrpt config
self.sel_r = [ 0 for ii in range(len(self.sel_a)) ]
Expand Down Expand Up @@ -193,6 +192,7 @@ def __init__ (self,
sel_a = self.sel_a,
sel_r = self.sel_r)
self.sub_sess = tf.Session(graph = sub_graph, config=default_tf_session_config)
self.original_sel = None


def get_rcut (self) -> float:
Expand Down Expand Up @@ -431,6 +431,9 @@ def build (self,
t_sel = tf.constant(self.sel_a,
name = 'sel',
dtype = tf.int32)
t_original_sel = tf.constant(self.original_sel if self.original_sel is not None else self.sel_a,
name = 'original_sel',
dtype = tf.int32)
self.t_avg = tf.get_variable('t_avg',
davg.shape,
dtype = GLOBAL_TF_FLOAT_PRECISION,
Expand All @@ -442,7 +445,8 @@ def build (self,
trainable = False,
initializer = tf.constant_initializer(dstd))

coord = tf.reshape (coord_, [-1, natoms[1] * 3])
with tf.control_dependencies([t_sel, t_original_sel]):
coord = tf.reshape (coord_, [-1, natoms[1] * 3])
box = tf.reshape (box_, [-1, 9])
atype = tf.reshape (atype_, [-1, natoms[1]])

Expand Down Expand Up @@ -552,13 +556,19 @@ def _pass_filter(self,
inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
output = []
output_qmat = []
if not self.type_one_side and type_embedding is None:
if not (self.type_one_side and len(self.exclude_types) == 0) and type_embedding is None:
for type_i in range(self.ntypes):
inputs_i = tf.slice (inputs,
[ 0, start_index* self.ndescrpt],
[-1, natoms[2+type_i]* self.ndescrpt] )
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
layer, qmat = self._filter(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
if self.type_one_side:
# reuse NN parameters for all types to support type_one_side along with exclude_types
reuse = tf.AUTO_REUSE
filter_name = 'filter_type_all'+suffix
else:
filter_name = 'filter_type_'+str(type_i)+suffix
layer, qmat = self._filter(inputs_i, type_i, name=filter_name, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3])
output.append(layer)
Expand Down Expand Up @@ -823,7 +833,12 @@ def _filter(
# inputs_reshape = tf.reshape(inputs, [-1, shape[1]//4, 4])
# natom x 4 x outputs_size
# xyz_scatter_1 = tf.matmul(inputs_reshape, xyz_scatter, transpose_a = True)
xyz_scatter_1 = xyz_scatter_1 * (4.0 / shape[1])
if self.original_sel is None:
# shape[1] = nnei * 4
nnei = shape[1] / 4
else:
nnei = tf.cast(tf.Variable(np.sum(self.original_sel), dtype=tf.int32, trainable=False, name="nnei"), self.filter_precision)
xyz_scatter_1 = xyz_scatter_1 / nnei
# natom x 4 x outputs_size_2
xyz_scatter_2 = tf.slice(xyz_scatter_1, [0,0,0],[-1,-1,outputs_size_2])
# # natom x 3 x outputs_size_2
Expand All @@ -838,3 +853,60 @@ def _filter(
result = tf.reshape(result, [-1, outputs_size_2 * outputs_size[-1]])

return result, qmat

def init_variables(self,
graph: tf.Graph,
graph_def: tf.GraphDef,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict
Parameters
----------
graph : tf.Graph
The input frozen model graph
graph_def : tf.GraphDef
The input frozen model graph_def
suffix : str, optional
The suffix of the scope
"""
super().init_variables(graph=graph, graph_def=graph_def, suffix=suffix)
try:
self.original_sel = get_tensor_by_name_from_graph(graph, 'descrpt_attr%s/original_sel' % suffix)
except GraphWithoutTensorError:
# original_sel is not restored in old graphs, assume sel never changed before
pass
# check sel == original sel?
try:
sel = get_tensor_by_name_from_graph(graph, 'descrpt_attr%s/sel' % suffix)
except GraphWithoutTensorError:
# sel is not restored in old graphs
pass
else:
if not np.array_equal(np.array(self.sel_a), sel):
if not self.set_davg_zero:
raise RuntimeError("Adjusting sel is only supported when `set_davg_zero` is true!")
# as set_davg_zero, self.davg is safely zero
self.davg = np.zeros([self.ntypes, self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
new_dstd = np.ones([self.ntypes, self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
# shape of davg and dstd is (ntypes, ndescrpt), ndescrpt = 4*sel
n_descpt = np.array(self.sel_a) * 4
n_descpt_old = np.array(sel) * 4
end_index = np.cumsum(n_descpt)
end_index_old = np.cumsum(n_descpt_old)
start_index = np.roll(end_index, 1)
start_index[0] = 0
start_index_old = np.roll(end_index_old, 1)
start_index_old[0] = 0

for nn, oo, ii, jj in zip(n_descpt, n_descpt_old, start_index, start_index_old):
if nn < oo:
# new size is smaller, copy part of std
new_dstd[:, ii:ii+nn] = self.dstd[:, jj:jj+nn]
else:
# new size is larger, copy all, the rest remains 1
new_dstd[:, ii:ii+oo] = self.dstd[:, jj:jj+oo]
self.dstd = new_dstd
if self.original_sel is None:
self.original_sel = sel
10 changes: 8 additions & 2 deletions deepmd/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,19 @@ def _pass_filter(self,
start_index = 0
inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
output = []
if not self.type_one_side:
if not (self.type_one_side and len(self.exclude_types) == 0):
for type_i in range(self.ntypes):
inputs_i = tf.slice (inputs,
[ 0, start_index* self.ndescrpt],
[-1, natoms[2+type_i]* self.ndescrpt] )
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
layer = self._filter_r(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
if self.type_one_side:
# reuse NN parameters for all types to support type_one_side along with exclude_types
reuse = tf.AUTO_REUSE
filter_name = 'filter_type_all'+suffix
else:
filter_name = 'filter_type_'+str(type_i)+suffix
layer = self._filter_r(inputs_i, type_i, name=filter_name, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
output.append(layer)
start_index += natoms[2+type_i]
Expand Down
1 change: 1 addition & 0 deletions deepmd/entrypoints/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def compress(
else:
log.info("stage 0: compute the min_nbor_dist")
jdata = j_loader(training_script)
jdata = update_deepmd_input(jdata)
t_min_nbor_dist = get_min_nbor_dist(jdata, get_rcut(jdata))

_check_compress_type(input)
Expand Down
7 changes: 5 additions & 2 deletions deepmd/entrypoints/convert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from deepmd.utils.convert import convert_20_to_21, convert_13_to_21, convert_12_to_21
from deepmd.utils.convert import convert_10_to_21, convert_20_to_21, convert_13_to_21, convert_12_to_21

def convert(
*,
Expand All @@ -7,7 +7,10 @@ def convert(
output_model: str,
**kwargs,
):
if FROM == '1.2':
if FROM == '1.0':
convert_10_to_21(input_model, output_model)
elif FROM in ['1.1', '1.2']:
# no difference between 1.1 and 1.2
convert_12_to_21(input_model, output_model)
elif FROM == '1.3':
convert_13_to_21(input_model, output_model)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def parse_args(args: Optional[List[str]] = None):
parser_transform.add_argument(
'FROM',
type = str,
choices = ['1.2', '1.3', '2.0'],
choices = ['1.0', '1.1', '1.2', '1.3', '2.0'],
help="The original model compatibility",
)
parser_transform.add_argument(
Expand Down
Loading

0 comments on commit c4f0cec

Please sign in to comment.