From 37a273cf0a1178991b748a11cdaf18826fc33ba3 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Tue, 8 Oct 2024 14:32:27 -0700 Subject: [PATCH] Disable jax in Gaussian process interp --- .../lsst/meas/algorithms/gp_interpolation.py | 156 +----------------- tests/test_gp_interp.py | 4 +- 2 files changed, 6 insertions(+), 154 deletions(-) diff --git a/python/lsst/meas/algorithms/gp_interpolation.py b/python/lsst/meas/algorithms/gp_interpolation.py index 6de217f9a..21d799c88 100644 --- a/python/lsst/meas/algorithms/gp_interpolation.py +++ b/python/lsst/meas/algorithms/gp_interpolation.py @@ -26,15 +26,10 @@ import copy import treegp -import jax -from jax import jit -import jax.numpy as jnp - import logging __all__ = [ "InterpolateOverDefectGaussianProcess", - "GaussianProcessJax", "GaussianProcessTreegp", ] @@ -62,7 +57,6 @@ def updateMaskFromArray(mask, bad_pixel, interpBit): # TO DO --> might be better: mask.array[int(bad_pixel[:,1]-y0), int(bad_pixel[:,0]-x)] |= interpBit -@jit def median_with_mad_clipping(data, mad_multiplier=2.0): """ Calculate the median of the input data after applying Median Absolute Deviation (MAD) clipping. @@ -90,145 +84,14 @@ def median_with_mad_clipping(data, mad_multiplier=2.0): >>> median_with_mad_clipping(data) 3.5 """ - median = jnp.median(data) - mad = jnp.median(jnp.abs(data - median)) + median = np.median(data) + mad = np.median(np.abs(data - median)) clipping_range = mad_multiplier * mad - clipped_data = jnp.clip(data, median - clipping_range, median + clipping_range) - median_clipped = jnp.median(clipped_data) + clipped_data = np.clip(data, median - clipping_range, median + clipping_range) + median_clipped = np.median(clipped_data) return median_clipped -@jit -def jax_rbf_kernel(x1, x2, sigma, correlation_length): - """ - Computes the radial basis function (RBF) kernel matrix. - - Parameters: - ----------- - x1 : `np.array` - Location of training data point with shape (n_samples, n_features). - x2 : `np.array` - Location of training/test data point with shape (n_samples, n_features). - sigma : `float` - The scale parameter of the kernel. - correlation_length : `float` - The correlation length parameter of the kernel. - - Returns: - -------- - kernel : `np.array` - RBF kernel matrix with shape (n_samples, n_samples). - """ - distance_squared = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1) - kernel = (sigma**2) * jnp.exp(-0.5 * distance_squared / (correlation_length**2)) - return kernel - - -@jit -def jax_get_alpha(y, kernel): - """ - Compute the alpha vector for Gaussian Process interpolation. - - Parameters: - ----------- - y : `np.array` - The target values of the Gaussian Process. - kernel : `np.array` - The kernel matrix of the Gaussian Process. - - Returns: - -------- - alpha : `np.array` - The alpha vector computed using the Cholesky decomposition and solution. - - """ - factor = (jax.scipy.linalg.cholesky(kernel, overwrite_a=True, lower=False), False) - alpha = jax.scipy.linalg.cho_solve(factor, y, overwrite_b=False) - return alpha.reshape((len(alpha), 1)) - - -@jit -def jax_get_gp_predict(kernel_rect, alpha): - """ - Compute the predicted values of gp using the given kernel and alpha (cholesky solution). - - Parameters: - ----------- - kernel_rect : `np.array` - The kernel matrix. - alpha : `np.array` - The alpha vector from Cholesky solution. - - Returns: - -------- - `np.array` - The predicted values of y. - - """ - return jnp.dot(kernel_rect, alpha).T[0] - - -class GaussianProcessJax: - """ - Gaussian Process regression in JAX. - Kernel is assumed to be isotropic RBF kernel, and solved - using exact Cholesky decomposition. - The interpolation solution is obtained by solving the linear system: - y_interp = kernel_rect @ (kernel + y_err**2 * I)^-1 @ y_training. - See the Rasmussen and Williams book for more details. - Each function is decorated with @jit to compile the function. - Exist package like tinygp, that is implemented in jax also. - This class is a custom implementation - of Gaussian Processes, which allows for setting the hyperparameters, - fine-tuning the mean function, - and other specifications. - - Parameters: - ----------- - std : `float`, optional - Standard deviation of the Gaussian Process kernel. Default is 1.0. - correlation_length : `float`, optional - Correlation length of the Gaussian Process kernel. Default is 1.0. - white_noise : `float`, optional - White noise level of the Gaussian Process. Default is 0.0. - mean : `float`, optional - Mean value of the Gaussian Process. Default is 0.0. - - """ - - def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): - self.std = std - self.correlation_length = correlation_length - self.white_noise = white_noise - self.mean = mean - self._alpha = None - - # Looks weird to do that, but this is justified. - # in GP if no noise is provided, even if matrix - # can be inverted, it wont invert because of numerical - # issue (det(K)~0). Add a little bit of noise allow - # to compute a numerical solution in the case of no - # external noise is added. Wont happened on real - # image but help for unit test. - if self.white_noise == 0.0: - self.white_noise = 1e-5 - - def fit(self, x_train, y_train): - y = y_train - self.mean - self._x = x_train - kernel = jax_rbf_kernel(x_train, x_train, self.std, self.correlation_length) - y_err = jnp.ones(len(x_train[:, 0])) * self.white_noise - kernel += jnp.eye(len(y_err)) * (y_err**2) - self._alpha = jax_get_alpha(y, kernel) - - def predict(self, x_predict): - kernel_rect = jax_rbf_kernel( - x_predict, self._x, self.std, self.correlation_length - ) - y_pred = jax_get_gp_predict(kernel_rect, self._alpha) - return y_pred + self.mean - - class GaussianProcessTreegp: """ Gaussian Process Treegp class for Gaussian Process interpolation. @@ -313,8 +176,6 @@ class InterpolateOverDefectGaussianProcess: The masked image containing the defects to be interpolated. defects : `list`[`str`], optional The types of defects to be interpolated. Default is ["SAT"]. - method : `str`, optional - The method to use for GP interpolation. Must be either "jax" or "treegp". Default is "treegp". fwhm : `float`, optional The full width at half maximum (FWHM) of the PSF. Default is 5. bin_spacing : `int`, optional @@ -334,7 +195,6 @@ def __init__( self, masked_image, defects=["SAT"], - method="treegp", fwhm=5, bin_image=True, bin_spacing=10, @@ -343,12 +203,6 @@ def __init__( correlation_length_cut=5, log=None, ): - if method == "jax": - self.GaussianProcess = GaussianProcessJax - elif method == "treegp": - self.GaussianProcess = GaussianProcessTreegp - else: - raise ValueError("Invalid method. Must be 'jax' or 'treegp'.") self.log = log or logging.getLogger(__name__) @@ -485,7 +339,7 @@ def interpolate_masked_sub_image(self, masked_sub_image): # put this after binning as computing median is O(n*log(n)) clipped_median = median_with_mad_clipping(good_pixel[:, 2:]) - gp = self.GaussianProcess( + gp = GaussianProcessTreegp( std=np.sqrt(kernel_amplitude), correlation_length=self.correlation_length, white_noise=white_noise, diff --git a/tests/test_gp_interp.py b/tests/test_gp_interp.py index 19f5c716b..5bd5c820b 100644 --- a/tests/test_gp_interp.py +++ b/tests/test_gp_interp.py @@ -142,8 +142,7 @@ def setUp(self): # np.random.seed(12345) # self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape) - @lsst.utils.tests.methodParameters(method=("jax")) - def test_interpolation(self, method: str): + def test_interpolation(self): """Test that the interpolation is done correctly. Parameters @@ -155,7 +154,6 @@ def test_interpolation(self, method: str): gp = InterpolateOverDefectGaussianProcess( self.maskedimage, defects=["BAD", "SAT", "CR", "EDGE"], - method=method, fwhm=self.correlation_length, bin_image=False, bin_spacing=30,