From ada465ec3d6ccd4f061d9e5be77a24d22a259167 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 23 Sep 2022 18:25:37 +0800 Subject: [PATCH] set ascend_mix --- source/tests/test_ascend_transfer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/source/tests/test_ascend_transfer.py b/source/tests/test_ascend_transfer.py index a3b5227058..a63086f5cf 100644 --- a/source/tests/test_ascend_transfer.py +++ b/source/tests/test_ascend_transfer.py @@ -11,7 +11,6 @@ from deepmd.utils.graph import get_tensor_by_name from deepmd.env import GLOBAL_NP_FLOAT_PRECISION -os.environ["DP_INTERFACE_PREC"] = "ascend_mix" def _file_delete(file) : if os.path.exists(file): @@ -30,6 +29,8 @@ def _subprocess_run(command): class TestTransform(unittest.TestCase) : @classmethod def setUpClass(self): + self.env = os.environ.get("DP_INTERFACE_PREC") + os.environ["DP_INTERFACE_PREC"] = "ascend_mix" self.old_model = str(tests_path / "dp-old.pb") self.new_model = str(tests_path / "dp-ascend.pb") convert_pbtxt_to_pb(str(tests_path / os.path.join("infer","deeppot-2.pbtxt")), self.old_model) @@ -55,6 +56,11 @@ def tearDownClass(self): _file_delete("input_v2_compat.json") _file_delete("lcurve.out") shutil.rmtree("model-transfer") + if self.env: + os.environ["DP_INTERFACE_PREC"] = self.env + else: + del os.environ['DP_INTERFACE_PREC'] + def test_attrs(self): self.assertEqual(self.dp.get_ntypes(), 2)