Enable pretraining for excited states, add pure-JAX GTO, make Scf traceable.

PiperOrigin-RevId: 578476292
Change-Id: I06144b9df36916546ca4c0caffe3777c49f1bac8
dpfau authored and jsspencer committed Nov 24, 2023
1 parent 51fad8f commit dd4571e
Showing 8 changed files with 748 additions and 41 deletions.
2 changes: 1 addition & 1 deletion ferminet/base_config.py
@@ -269,7 +269,7 @@ def default() -> ml_collections.ConfigDict:
     'pretrain': {
         'method': 'hf',  # Currently only 'hf' is supported.
         'iterations': 1000,  # Only used if method is 'hf'.
-        'basis': 'sto-6g',
+        'basis': 'ccpvdz',  # Larger than STO-6G, but good for excited states
     },
 })

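A quick sketch (not part of this commit) of overriding the new default from user code, assuming the standard base_config.default() entry point shown above:

    from ferminet import base_config

    cfg = base_config.default()
    # 'ccpvdz' is now the default pretraining basis; a smaller basis such as
    # 'sto-6g' remains a reasonable choice for ground-state-only runs.
    cfg.pretrain.basis = 'sto-6g'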
2 changes: 1 addition & 1 deletion ferminet/configs/li_excited.py
@@ -68,7 +68,7 @@ def get_config():
   cfg.system.charge = 0
   cfg.system.delta_charge = 0.0
   cfg.system.states = 3
-  cfg.pretrain.iterations = 0
+  cfg.pretrain.iterations = 1000
   cfg.optim.reset_if_nan = True
   cfg.system.spin_polarisation = ml_collections.FieldReference(
       None, field_type=int)
99 changes: 77 additions & 22 deletions ferminet/pretrain.py
@@ -35,7 +35,8 @@ def get_hf(molecule: Optional[Sequence[system.Atom]] = None,
            nspins: Optional[Tuple[int, int]] = None,
            basis: Optional[str] = 'sto-3g',
            pyscf_mol: Optional[pyscf.gto.Mole] = None,
-           restricted: Optional[bool] = False) -> scf.Scf:
+           restricted: Optional[bool] = False,
+           states: int = 0) -> scf.Scf:
   """Returns an Scf object with the Hartree-Fock solution to the system.
 
   Args:
@@ -46,13 +47,16 @@ def get_hf(molecule: Optional[Sequence[system.Atom]] = None,
       molecule, nspins and basis are ignored.
     restricted: If true, perform a restricted Hartree-Fock calculation,
       otherwise perform an unrestricted Hartree-Fock calculation.
+    states: Number of excited states. If nonzero, compute all single and double
+      excitations of the Hartree-Fock solution and return coefficients for the
+      lowest ones.
   """
   if pyscf_mol:
     scf_approx = scf.Scf(pyscf_mol=pyscf_mol, restricted=restricted)
   else:
     scf_approx = scf.Scf(
         molecule, nelectrons=nspins, basis=basis, restricted=restricted)
-  scf_approx.run()
+  scf_approx.run(excitations=max(states - 1, 0))
   return scf_approx
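For reference, a minimal usage sketch of the extended get_hf (the molecule, spin counts and basis below are illustrative, not from this commit; assumes the ferminet.utils.system.Atom helper):

    from ferminet import pretrain
    from ferminet.utils import system

    # Hartree-Fock reference for three states: ground state + 2 excitations.
    mol = [system.Atom('Li', (0.0, 0.0, 0.0))]
    scf_approx = pretrain.get_hf(
        molecule=mol, nspins=(2, 1), basis='ccpvdz', restricted=False,
        states=3)
    # Internally this calls scf_approx.run(excitations=max(3 - 1, 0)), i.e.
    # the two lowest excitations on top of the Hartree-Fock ground state.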


@@ -116,7 +120,10 @@ def make_pretrain_step(
     batch_orbitals: networks.OrbitalFnLike,
     batch_network: networks.LogFermiNetLike,
     optimizer_update: optax.TransformUpdateFn,
+    scf_approx: scf.Scf,
+    electrons: Tuple[int, int],
     full_det: bool = False,
+    states: int = 0,
 ):
   """Creates function for performing one step of Hartree-Fock pretraining.
@@ -129,43 +136,87 @@ def make_pretrain_step(
       magnitude of the (wavefunction) network evaluated at those positions.
     optimizer_update: callable for transforming the gradients into an update (ie
       conforms to the optax API).
+    scf_approx: an scf.Scf object that contains the result of a PySCF
+      calculation.
+    electrons: number of spin-up and spin-down electrons.
     full_det: If true, evaluate all electrons in a single determinant.
       Otherwise, evaluate products of alpha- and beta-spin determinants.
+    states: Number of excited states, if not 0.
 
   Returns:
     Callable for performing a single pretraining optimisation step.
   """

-  def pretrain_step(data, target, params, state, key, logprob):
+  def pretrain_step(data, params, state, key, logprob):
     """One iteration of pretraining to match HF."""
 
     cnorm = lambda x, y: (x - y) * jnp.conj(x - y)  # complex norm
 
     def loss_fn(
-        params: networks.ParamTree, data: networks.FermiNetData, target: ...
+        params: networks.ParamTree,
+        data: networks.FermiNetData,
     ):
-      orbitals = batch_orbitals(
-          params, data.positions, data.spins, data.atoms, data.charges
-      )
+      pos = data.positions
+      spins = data.spins
+      if states:
+        # Make vmap-ed versions of eval_orbitals and batch_orbitals over the
+        # states dimension.
+        # (batch, states, nelec*ndim)
+        pos = jnp.reshape(pos, pos.shape[:-1] + (states, -1))
+        # (batch, states, nelec)
+        spins = jnp.reshape(spins, spins.shape[:-1] + (states, -1))
+
+        scf_orbitals = jax.vmap(
+            scf_approx.eval_orbitals, in_axes=(-2, None), out_axes=-4
+        )
+
+        def net_orbitals(params, pos, spins, atoms, charges):
+          vmapped_orbitals = jax.vmap(
+              batch_orbitals, in_axes=(None, -2, -2, None, None), out_axes=-4
+          )
+          # Dimensions of result are
+          # [(batch, states, ndet*states, nelec, nelec)]
+          result = vmapped_orbitals(params, pos, spins, atoms, charges)
+          result = [
+              jnp.reshape(r, r.shape[:-3] + (states, -1) + r.shape[-2:])
+              for r in result
+          ]
+          result = [jnp.transpose(r, (0, 3, 1, 2, 4, 5)) for r in result]
+          # We draw distinct samples for each excited state (electron
+          # configuration), and then evaluate each state within each sample.
+          # Output dimensions are:
+          # (batch, det, electron configuration,
+          #  excited state, electron, orbital)
+          return result
+
+      else:
+        scf_orbitals = scf_approx.eval_orbitals
+        net_orbitals = batch_orbitals
+
+      target = scf_orbitals(pos, electrons)
+      orbitals = net_orbitals(params, pos, spins, data.atoms, data.charges)
       if full_det:
-        ndet = target[0].shape[0]
-        na = target[0].shape[1]
-        nb = target[1].shape[1]
+        dims = target[0].shape[:-2]  # (batch) or (batch, states).
+        na = target[0].shape[-2]
+        nb = target[1].shape[-2]
-        target = jnp.concatenate(
-            (jnp.concatenate((target[0], jnp.zeros((ndet, na, nb))), axis=-1),
-             jnp.concatenate((jnp.zeros((ndet, nb, na)), target[1]), axis=-1)),
-            axis=-2)
+        target = jnp.concatenate(
+            (
+                jnp.concatenate(
+                    (target[0], jnp.zeros(dims + (na, nb))), axis=-1),
+                jnp.concatenate(
+                    (jnp.zeros(dims + (nb, na)), target[1]), axis=-1),
+            ),
+            axis=-2,
+        )
         result = jnp.mean(cnorm(target[:, None, ...], orbitals[0])).real
       else:
-        result = jnp.array(
-            [
-                jnp.mean(cnorm(t[:, None, ...], o)).real
-                for t, o in zip(target, orbitals)
-            ]
-        ).sum()
+        result = jnp.array([
+            jnp.mean(cnorm(t[:, None, ...], o)).real
+            for t, o in zip(target, orbitals)
+        ]).sum()
       return constants.pmean(result)
 
     val_and_grad = jax.value_and_grad(loss_fn, argnums=0)
-    loss_val, search_direction = val_and_grad(params, data, target)
+    loss_val, search_direction = val_and_grad(params, data)
     search_direction = constants.pmean(search_direction)
     updates, state = optimizer_update(search_direction, state, params)
     params = optax.apply_updates(params, updates)
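The states branch above unflattens the per-state walker configurations and maps the orbital evaluation over them with jax.vmap. A self-contained toy sketch of the same reshape-and-vmap pattern (all shapes illustrative):

    import jax
    import jax.numpy as jnp

    batch, states, nelec, ndim = 4, 2, 3, 3
    # Flattened positions, as stored in data.positions: one row per sample.
    pos = jnp.zeros((batch, states * nelec * ndim))

    # Unflatten to (batch, states, nelec*ndim), as in loss_fn above.
    pos = jnp.reshape(pos, pos.shape[:-1] + (states, -1))
    assert pos.shape == (4, 2, 9)

    # Map a per-configuration function over the states axis, mirroring
    # jax.vmap(batch_orbitals, in_axes=(None, -2, -2, None, None), ...).
    def eval_one_config(x):  # stand-in for a batched orbital evaluation
      return jnp.sum(x, axis=-1)

    per_state = jax.vmap(eval_one_config, in_axes=-2, out_axes=-1)
    assert per_state(pos).shape == (4, 2)  # (batch, states)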
@@ -191,6 +242,7 @@ def pretrain_hartree_fock(
     scf_approx: scf.Scf,
     iterations: int = 1000,
     logger: Optional[Callable[[int, float], None]] = None,
+    states: int = 0,
 ):
   """Performs training to match initialization as closely as possible to HF.
@@ -216,6 +268,7 @@ def pretrain_hartree_fock(
     iterations: number of pretraining iterations to perform.
     logger: Callable with signature (step, value) which externally logs the
       pretraining loss.
+    states: Number of excited states, if not 0.
 
   Returns:
     params, positions: Updated network parameters and MCMC configurations such
@@ -234,7 +287,10 @@ def pretrain_hartree_fock(
       batch_orbitals,
       batch_network,
       optimizer.update,
+      scf_approx=scf_approx,
+      electrons=electrons,
       full_det=network_options.full_det,
+      states=states,
   )
   pretrain_step = constants.pmap(pretrain_step)
   pnetwork = constants.pmap(batch_network)
@@ -247,10 +303,9 @@ def pretrain_hartree_fock(
   logprob = 2.0 * pnetwork(params, positions, pmap_spins, atoms, charges)
 
   for t in range(iterations):
-    target = eval_orbitals(scf_approx, data.positions, electrons)
     sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
     data, params, opt_state_pt, loss, logprob = pretrain_step(
-        data, target, params, opt_state_pt, subkeys, logprob)
+        data, params, opt_state_pt, subkeys, logprob)
     logging.info('Pretrain iter %05d: %g', t, loss[0])
     if logger:
       logger(t, loss[0])
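In the full_det branch of loss_fn, the alpha- and beta-spin Hartree-Fock orbital blocks are zero-padded into a single block-diagonal target matrix so they can be compared against one full determinant. A toy numeric sketch (shapes and values illustrative):

    import jax.numpy as jnp

    # One batch element, two spin-up and one spin-down electrons.
    up = jnp.ones((1, 2, 2))          # (batch, n_alpha, n_alpha)
    down = 2.0 * jnp.ones((1, 1, 1))  # (batch, n_beta, n_beta)

    dims = up.shape[:-2]  # (batch,); (batch, states) with excited states.
    na, nb = up.shape[-2], down.shape[-2]
    target = jnp.concatenate(
        (jnp.concatenate((up, jnp.zeros(dims + (na, nb))), axis=-1),
         jnp.concatenate((jnp.zeros(dims + (nb, na)), down), axis=-1)),
        axis=-2)
    # target[0] is block-diagonal:
    # [[1. 1. 0.]
    #  [1. 1. 0.]
    #  [0. 0. 2.]]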
2 changes: 1 addition & 1 deletion ferminet/tests/train_test.py
@@ -87,7 +87,7 @@ def test_training_step(self, system, optimizer, complex_, states):
     cfg.network.complex = complex_
     cfg.batch_size = 32
     cfg.system.states = states
-    cfg.pretrain.iterations = 10 if states == 0 else 0
+    cfg.pretrain.iterations = 10
     cfg.mcmc.burn_in = 10
     cfg.optim.optimizer = optimizer
     cfg.optim.iterations = 3
7 changes: 3 additions & 4 deletions ferminet/train.py
@@ -408,15 +408,13 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None):
   # Create parameters, network, and vmapped/pmapped derivations
 
   if cfg.pretrain.method == 'hf' and cfg.pretrain.iterations > 0:
-    if cfg.system.states > 1:
-      raise NotImplementedError(
-          'Pretraining not yet implemented for excited states')
     hartree_fock = pretrain.get_hf(
         pyscf_mol=cfg.system.get('pyscf_mol'),
         molecule=cfg.system.molecule,
         nspins=nspins,
         restricted=False,
-        basis=cfg.pretrain.basis)
+        basis=cfg.pretrain.basis,
+        states=cfg.system.states)
     # broadcast the result of PySCF from host 0 to all other hosts
     hartree_fock.mean_field.mo_coeff = multihost_utils.broadcast_one_to_all(
         hartree_fock.mean_field.mo_coeff
@@ -584,6 +582,7 @@ def log_network(*args, **kwargs):
       electrons=cfg.system.electrons,
       scf_approx=hartree_fock,
       iterations=cfg.pretrain.iterations,
+      states=cfg.system.states,
   )
 
   # Main training
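With the NotImplementedError above removed, excited-state runs can now be pretrained out of the box. A minimal sketch (system setup elided; assumes the standard ferminet entry points):

    from ferminet import base_config
    from ferminet import train

    cfg = base_config.default()
    # ... configure cfg.system (molecule, electrons, etc.) as usual ...
    cfg.system.states = 3           # ground state plus two excited states
    cfg.pretrain.iterations = 1000  # pretraining is no longer skipped
    train.train(cfg)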