From 2303ff0af9240ec3a4ee70fc86a2d043b93257b3 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Fri, 22 Nov 2024 23:04:01 -0500
Subject: [PATCH] fix(pt): optimize createNlistTensor (#4403)

## Summary by CodeRabbit

- **New Features**
  - Enhanced tensor creation process for improved performance and efficiency.

- **Bug Fixes**
  - Improved error handling for PyTorch-related exceptions, providing clearer
    error messages.

Signed-off-by: Jinzhe Zeng
---
 source/api_cc/src/DeepPotPT.cc  | 22 +++++++++++-----------
 source/api_cc/src/DeepSpinPT.cc | 22 +++++++++++-----------
 2 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc
index b431ad65cf..6910de3ccd 100644
--- a/source/api_cc/src/DeepPotPT.cc
+++ b/source/api_cc/src/DeepPotPT.cc
@@ -31,20 +31,20 @@ void DeepPotPT::translate_error(std::function<void()> f) {
 }
 
 torch::Tensor createNlistTensor(const std::vector<std::vector<int>>& data) {
-  std::vector<torch::Tensor> row_tensors;
-
+  size_t total_size = 0;
   for (const auto& row : data) {
-    torch::Tensor row_tensor = torch::tensor(row, torch::kInt32).unsqueeze(0);
-    row_tensors.push_back(row_tensor);
+    total_size += row.size();
   }
-
-  torch::Tensor tensor;
-  if (row_tensors.size() > 0) {
-    tensor = torch::cat(row_tensors, 0).unsqueeze(0);
-  } else {
-    tensor = torch::empty({1, 0, 0}, torch::kInt32);
+  std::vector<int> flat_data;
+  flat_data.reserve(total_size);
+  for (const auto& row : data) {
+    flat_data.insert(flat_data.end(), row.begin(), row.end());
   }
-  return tensor;
+
+  torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32);
+  int nloc = data.size();
+  int nnei = nloc > 0 ? total_size / nloc : 0;
+  return flat_tensor.view({1, nloc, nnei});
 }
 DeepPotPT::DeepPotPT() : inited(false) {}
 DeepPotPT::DeepPotPT(const std::string& model,
diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc
index 1a245c7b2e..aef2d60150 100644
--- a/source/api_cc/src/DeepSpinPT.cc
+++ b/source/api_cc/src/DeepSpinPT.cc
@@ -31,20 +31,20 @@ void DeepSpinPT::translate_error(std::function<void()> f) {
 }
 
 torch::Tensor createNlistTensor2(const std::vector<std::vector<int>>& data) {
-  std::vector<torch::Tensor> row_tensors;
-
+  size_t total_size = 0;
   for (const auto& row : data) {
-    torch::Tensor row_tensor = torch::tensor(row, torch::kInt32).unsqueeze(0);
-    row_tensors.push_back(row_tensor);
+    total_size += row.size();
   }
-
-  torch::Tensor tensor;
-  if (row_tensors.size() > 0) {
-    tensor = torch::cat(row_tensors, 0).unsqueeze(0);
-  } else {
-    tensor = torch::empty({1, 0, 0}, torch::kInt32);
+  std::vector<int> flat_data;
+  flat_data.reserve(total_size);
+  for (const auto& row : data) {
+    flat_data.insert(flat_data.end(), row.begin(), row.end());
   }
-  return tensor;
+
+  torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32);
+  int nloc = data.size();
+  int nnei = nloc > 0 ? total_size / nloc : 0;
+  return flat_tensor.view({1, nloc, nnei});
 }
 DeepSpinPT::DeepSpinPT() : inited(false) {}
 DeepSpinPT::DeepSpinPT(const std::string& model,