Skip to content

Commit

Permalink
fix(cc): copy nloc atoms from neighbor list (#4459)
Browse files Browse the repository at this point in the history
Prevent that the size of the neighbor list is larger than nloc.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced flexibility in copying neighbor list data with the addition
of a `natoms` parameter.
- Improved handling of neighbor list data in the `compute` methods
across multiple classes.

- **Bug Fixes**
- Refined error handling in the `translate_error` method for better
clarity on exceptions.

- **Documentation**
- Updated method documentation to reflect changes in parameters and
usage.

- **Style**
	- Adjusted code structure for better readability and maintainability.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Dec 9, 2024
1 parent d162d0b commit ec3b83f
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 6 deletions.
8 changes: 7 additions & 1 deletion source/api_cc/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ struct NeighborListData {
std::vector<int*> firstneigh;

public:
void copy_from_nlist(const InputNlist& inlist);
/**
* @brief Copy the neighbor list from an InputNlist.
* @param[in] inlist The input neighbor list.
* @param[in] natoms The number of atoms to copy. If natoms is -1, copy all
* atoms.
*/
void copy_from_nlist(const InputNlist& inlist, const int natoms = -1);
void shuffle(const std::vector<int>& fwd_map);
void shuffle(const deepmd::AtomMap& map);
void shuffle_exclude_empty(const std::vector<int>& fwd_map);
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepPotJAX.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ void deepmd::DeepPotJAX::compute(std::vector<ENERGYTYPE>& ener,
input_list[1] = add_input(op, atype, atype_shape, data_tensor[1], status);
// nlist
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list);
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
nlist_data.shuffle_exclude_empty(fwd_map);
}
size_t max_size = 0;
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
at::Tensor atype_Tensor =
torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list);
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing) {
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepSpinPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device);
c10::optional<torch::Tensor> mapping_tensor;
if (ago == 0) {
nlist_data.copy_from_nlist(lmp_list);
nlist_data.copy_from_nlist(lmp_list, nall - nghost);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing) {
Expand Down
5 changes: 3 additions & 2 deletions source/api_cc/src/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,9 @@ template void deepmd::select_real_atoms_coord<float>(
const int& nall,
const bool aparam_nall);

void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist) {
int inum = inlist.inum;
void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist,
const int natoms) {
int inum = natoms >= 0 ? natoms : inlist.inum;
ilist.resize(inum);
jlist.resize(inum);
memcpy(&ilist[0], inlist.ilist, inum * sizeof(int));
Expand Down

0 comments on commit ec3b83f

Please sign in to comment.