From 67cf7954a93b62318f5349ec208f7153e752f705 Mon Sep 17 00:00:00 2001 From: FermiNet Contributor Date: Mon, 19 Aug 2024 13:12:35 +0100 Subject: [PATCH] Remove the `TwoKroneckerFactored` class and use the `KroneckerFactored` class instead. This is a no-op change. It just simplifies the code and makes it easier to add support for more than two parameters. PiperOrigin-RevId: 664755711 Change-Id: Ia231530d7e714a0a44b22babdc1a008de6fe9340 --- ferminet/curvature_tags_and_blocks.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/ferminet/curvature_tags_and_blocks.py b/ferminet/curvature_tags_and_blocks.py index c854fbe..46881b7 100644 --- a/ferminet/curvature_tags_and_blocks.py +++ b/ferminet/curvature_tags_and_blocks.py @@ -54,13 +54,13 @@ def fixed_scale(self) -> Numeric: def update_curvature_matrix_estimate( self, - state: kfac_jax.TwoKroneckerFactored.State, + state: kfac_jax.KroneckerFactored.State, estimation_data: kfac_jax.LayerVjpData[Array], ema_old: Numeric, ema_new: Numeric, identity_weight: Numeric, batch_size: int, - ) -> kfac_jax.TwoKroneckerFactored.State: + ) -> kfac_jax.KroneckerFactored.State: [x] = estimation_data.primals.inputs [dy] = estimation_data.tangents.outputs assert x.shape[0] == batch_size @@ -88,7 +88,7 @@ def update_curvature_matrix_estimate( ) -class QmcBlockedDense(kfac_jax.TwoKroneckerFactored): +class QmcBlockedDense(kfac_jax.KroneckerFactored): """A factor that is the Kronecker product of two matrices.""" def input_size(self) -> int: @@ -102,13 +102,13 @@ def fixed_scale(self) -> Numeric: def update_curvature_matrix_estimate( self, - state: kfac_jax.TwoKroneckerFactored.State, + state: kfac_jax.KroneckerFactored.State, estimation_data: kfac_jax.LayerVjpData[Array], ema_old: Numeric, ema_new: Numeric, identity_weight: Numeric, batch_size: int, - ) -> kfac_jax.TwoKroneckerFactored.State: + ) -> kfac_jax.KroneckerFactored.State: del identity_weight [x] = estimation_data.primals.inputs @@ -132,7 +132,7 @@ def _init( exact_powers_to_cache: Set[Scalar], approx_powers_to_cache: Set[Scalar], cache_eigenvalues: bool, - ) -> kfac_jax.TwoKroneckerFactored.State: + ) -> kfac_jax.KroneckerFactored.State: del rng, cache_eigenvalues k, m, j, n = self.parameters_shapes[0] cache = dict() @@ -147,7 +147,7 @@ def _init( inputs_factor=jnp.zeros([j, k, k]), outputs_factor=jnp.zeros([j, m * n, m * n]), ) - return kfac_jax.TwoKroneckerFactored.State( + return kfac_jax.KroneckerFactored.State( cache=cache, inputs_factor= kfac_jax.utils.WeightedMovingAverage.zeros_array((j, k, k)), @@ -157,12 +157,12 @@ def _init( def _update_cache( self, - state: kfac_jax.TwoKroneckerFactored.State, + state: kfac_jax.KroneckerFactored.State, identity_weight: kfac_jax.utils.Numeric, exact_powers: set[kfac_jax.utils.Scalar], approx_powers: set[kfac_jax.utils.Scalar], eigenvalues: bool, - ) -> kfac_jax.TwoKroneckerFactored.State: + ) -> kfac_jax.KroneckerFactored.State: del eigenvalues if exact_powers: @@ -186,7 +186,7 @@ def _update_cache( def multiply_matpower( self, - state: kfac_jax.TwoKroneckerFactored.State, + state: kfac_jax.KroneckerFactored.State, vector: Sequence[Array], identity_weight: Numeric, power: Scalar,