Skip to content

Commit

Permalink
sort by atom idx for dump (#388)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Dec 2, 2022
1 parent be729bd commit b41d595
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 9 deletions.
16 changes: 7 additions & 9 deletions dpdata/lammps/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,9 @@ def get_atype(lines, type_idx_zero = False) :
tidx = keys.index('type') - 2
atype = []
for ii in blk :
atype.append([int(ii.split()[tidx]), int(ii.split()[id_idx])])
# sort with type id
atype.append([int(ii.split()[id_idx]), int(ii.split()[tidx])])
atype.sort()
atype = np.array(atype, dtype = int)
atype = atype[:, ::-1]
if type_idx_zero :
return atype[:,1] - 1
else :
Expand Down Expand Up @@ -78,17 +76,15 @@ def safe_get_posi(lines,cell,orig=np.zeros(3), unwrap=False) :
assert coord_tp_and_sf is not None, 'Dump file does not contain atomic coordinates!'
coordtype, sf, uw = coord_tp_and_sf
id_idx = keys.index('id') - 2
tidx = keys.index('type') - 2
xidx = keys.index(coordtype[0])-2
yidx = keys.index(coordtype[1])-2
zidx = keys.index(coordtype[2])-2
sel = (xidx, yidx, zidx)
posis = []
for ii in blk :
words = ii.split()
posis.append([float(words[tidx]), float(words[id_idx]), float(words[xidx]), float(words[yidx]), float(words[zidx])])
posis.append([float(words[id_idx]), float(words[xidx]), float(words[yidx]), float(words[zidx])])
posis.sort()
posis = np.array(posis)[:,2:5]
posis = np.array(posis)[:,1:4]
if not sf:
posis = (posis - orig) @ np.linalg.inv(cell) # Convert to scaled coordinates for unscaled coordinates
if uw and unwrap:
Expand Down Expand Up @@ -178,14 +174,16 @@ def system_data(lines, type_map = None, type_idx_zero = True, unwrap=False) :
orig, cell = dumpbox2box(bounds, tilt)
system['orig'] = np.array(orig) - np.array(orig)
system['cells'] = [np.array(cell)]
natoms = sum(system['atom_numbs'])
system['atom_types'] = get_atype(lines, type_idx_zero = type_idx_zero)
system['coords'] = [safe_get_posi(lines, cell, np.array(orig), unwrap)]
for ii in range(1, len(array_lines)) :
bounds, tilt = get_dumpbox(array_lines[ii])
orig, cell = dumpbox2box(bounds, tilt)
system['cells'].append(cell)
system['coords'].append(safe_get_posi(array_lines[ii], cell, np.array(orig), unwrap))
atype = get_atype(array_lines[ii], type_idx_zero = type_idx_zero)
# map atom type; a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id]
idx = np.argsort(atype)[np.argsort(np.argsort(system['atom_types']))]
system['coords'].append(safe_get_posi(array_lines[ii], cell, np.array(orig), unwrap)[idx])
system['cells'] = np.array(system['cells'])
system['coords'] = np.array(system['coords'])
return system
Expand Down
11 changes: 11 additions & 0 deletions tests/poscars/conf2.dump
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
ITEM: TIMESTEP
0
ITEM: NUMBER OF ATOMS
2
ITEM: BOX BOUNDS xy xz yz pp pp pp
0.0 5.0739861 1.2621856
0.0 2.7916155 1.2874292
0.0 2.2254033 0.7485898
ITEM: ATOMS id type x y z
1 2 0.0 0.0 0.0
2 1 1.2621856 0.7018028 0.5513885
21 changes: 21 additions & 0 deletions tests/test_lammps_dump_idx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# The index should map to that in the dump file

import os
import numpy as np
import unittest
from context import dpdata

class TestLmpDumpIdx(unittest.TestCase):
def setUp(self):
self.system = dpdata.System(os.path.join('poscars', 'conf2.dump'))

def test_coords(self):
np.testing.assert_allclose(self.system['coords'], np.array(
[[[0., 0., 0.],
[1.2621856, 0.7018028, 0.5513885]]]
))

def test_type(self):
np.testing.assert_allclose(self.system.get_atom_types(), np.array(
[1, 0], dtype=int,
))
1 change: 1 addition & 0 deletions tests/test_lammps_read_from_trajs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def setUp(self):
dpdata.System(os.path.join('lammps', 'traj_with_random_type_id.dump'), fmt = 'lammps/dump', type_map = ["Ta","Nb","W","Mo","V","Al"])

def test_nframes (self) :
self.system.sort_atom_types()
atype = self.system['atom_types'].tolist()
self.assertTrue(atype == [1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 5])

Expand Down

0 comments on commit b41d595

Please sign in to comment.