From e55200aec1ee22edd32c8f93d7e7bf7fbf161184 Mon Sep 17 00:00:00 2001 From: Pierre-Francois Leget Date: Tue, 17 Sep 2024 09:01:48 -0700 Subject: [PATCH] continue work on Clares comments --- .../lsst/meas/algorithms/gp_interpolation.py | 158 ++++++------------ 1 file changed, 53 insertions(+), 105 deletions(-) diff --git a/python/lsst/meas/algorithms/gp_interpolation.py b/python/lsst/meas/algorithms/gp_interpolation.py index c66233ade..5eadbf4aa 100644 --- a/python/lsst/meas/algorithms/gp_interpolation.py +++ b/python/lsst/meas/algorithms/gp_interpolation.py @@ -73,14 +73,16 @@ def median_with_mad_clipping(data, mad_multiplier=2.0): @jit -def jax_rbf_kernel(x, sigma, correlation_length, y_err): +def jax_rbf_kernel(x1, x2, sigma, correlation_length, y_err): """ Computes the radial basis function (RBF) kernel matrix. Parameters: ----------- - x : `np.array` - Input data points with shape (n_samples, n_features). + x1 : `np.array` + Location of training data point with shape (n_samples, n_features). + x2 : `np.array` + Location of training/test data point with shape (n_samples, n_features). sigma : `float` The scale parameter of the kernel. correlation_length : `float` @@ -93,39 +95,12 @@ def jax_rbf_kernel(x, sigma, correlation_length, y_err): kernel : `np.array` RBF kernel matrix with shape (n_samples, n_samples). """ - distance_squared = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1) + distance_squared = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1) kernel = (sigma**2) * jnp.exp(-0.5 * distance_squared / (correlation_length**2)) y_err = jnp.ones(len(x[:, 0])) * y_err kernel += jnp.eye(len(y_err)) * (y_err**2) return kernel - -@jit -def jax_rbf_kernel_rect(x1, x2, sigma, correlation_length): - """ - Compute the radial basis function (RBF) kernel (rectangular matrix). - Parameters: - ----------- - x1 : `np.array` - The first set of input points. - x2 : `np.array` - The second set of input points. - sigma : `float` - The scale parameter of the kernel. - correlation_length : float - The correlation length parameter of the kernel. - - Returns: - -------- - kernel_rect : `np.array` - The computed RBF kernel (rectangular matrix). - - """ - l1 = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1) - kernel_rect = (sigma**2) * jnp.exp(-0.5 * l1 / (correlation_length**2)) - return kernel_rect - - @jit def jax_get_alpha(y, kernel): """ @@ -141,7 +116,7 @@ def jax_get_alpha(y, kernel): Returns: -------- alpha : `np.array` - The alpha vector computed using the Cholesky decomposition and solve. + The alpha vector computed using the Cholesky decomposition and solution. """ factor = (jax.scipy.linalg.cholesky(kernel, overwrite_a=True, lower=False), False) @@ -150,16 +125,16 @@ def jax_get_alpha(y, kernel): @jit -def jax_get_y_predict(kernel_rect, alpha): +def jax_get_gp_predict(kernel_rect, alpha): """ - Compute the predicted values of y using the given kernel and alpha. + Compute the predicted values of gp using the given kernel and alpha (cholesky solution). Parameters: ----------- kernel_rect : `np.array` The kernel matrix. alpha : `np.array` - The alpha vector. + The alpha vector from Cholesky solution. Returns: -------- @@ -204,24 +179,23 @@ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): self.mean = mean self._alpha = None - def fit(self, x_good, y_good): - y = y_good - self.mean - self._x = x_good - kernel = jax_rbf_kernel(x_good, self.std, self.correlation_length, self.white_noise) + def fit(self, x_train, y_train): + y = y_train - self.mean + self._x = x_train + kernel = jax_rbf_kernel(x_train, self.std, self.correlation_length, self.white_noise) self._alpha = jax_get_alpha(y, kernel) - def predict(self, x_bad): - kernel_rect = jax_rbf_kernel_rect(x_bad, self._x, self.std, self.correlation_length) - y_pred = jax_get_y_predict(kernel_rect, self._alpha) + def predict(self, x_predict): + kernel_rect = jax_rbf_kernel(x_predict, self._x, self.std, self.correlation_length, 0) + y_pred = jax_get_gp_predict(kernel_rect, self._alpha) return y_pred + self.mean - -# Vanilla Gaussian Process regression using treegp package -# There is no fancy O(N*log(N)) solver here, just the basic GP regression (Cholesky). class GaussianProcessTreegp: """ Gaussian Process Treegp class for Gaussian Process interpolation. + The basic GP regression, which uses Cholesky decomposition. + Parameters: ----------- std : `float`, optional @@ -240,15 +214,15 @@ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): self.white_noise = white_noise self.mean = mean - def fit(self, x_good, y_good): + def fit(self, x_train, y_train): """ Fit the Gaussian Process to the given training data. Parameters: ----------- - x_good : `np.array` + x_train : `np.array` Input features for the training data. - y_good : `np.array` + y_train : `np.array` Target values for the training data. """ kernel = f"{self.std}**2 * RBF({self.correlation_length})" @@ -258,16 +232,16 @@ def fit(self, x_good, y_good): normalize=False, white_noise=self.white_noise, ) - self.gp.initialize(x_good, y_good - self.mean) + self.gp.initialize(x_train, y_train - self.mean) self.gp.solve() - def predict(self, x_bad): + def predict(self, x_predict): """ Predict the target values for the given input features. Parameters: ----------- - x_bad : `np.array` + x_predict : `np.array` Input features for the prediction. Returns: @@ -275,7 +249,7 @@ def predict(self, x_bad): y_pred : `np.array` Predicted target values. """ - y_pred = self.gp.predict(x_bad) + y_pred = self.gp.predict(x_predict) return y_pred + self.mean @@ -303,30 +277,6 @@ class InterpolateOverDefectGaussianProcess: correlation_length_cut : `int`, optional The factor by which to dilate the bounding box around defects. Default is 5. - Raises: - ------- - ValueError - If an invalid method is provided. - - Attributes: - ----------- - bin_spacing : `int` - The spacing between bins for good pixel binning. - threshold_subdivide : `int` - The threshold for sub-dividing the bad pixel array to avoid memory error. - threshold_dynamic_binning : `int` - The threshold for dynamic binning. - maskedImage : `lsst.afw.image.MaskedImage` - The masked image containing the defects to be interpolated. - defects : `list`[`str`] - The types of defects to be interpolated. - correlation_length : `float` - The correlation length (FWHM). - correlation_length_cut : `int` - The factor by which to dilate the bounding box around defects. - interpBit : `int` - The bit mask for the "INTRP" plane in the image mask. - """ def __init__( @@ -360,7 +310,9 @@ def __init__( def interpolate_over_defects(self): """ - Interpolates over the defects in the image. + Interpolate over the defects in the image. + + Change self.maskedImage . """ mask = self.maskedImage.getMask() @@ -368,20 +320,20 @@ def interpolate_over_defects(self): badMaskSpanSet = SpanSet.fromMask(mask, badPixelMask).split() bbox = self.maskedImage.getBBox() - glob_xmin, glob_xmax = bbox.minX, bbox.maxX - glob_ymin, glob_ymax = bbox.minY, bbox.maxY + global_xmin, global_xmax = bbox.minX, bbox.maxX + global_ymin, global_ymax = bbox.minY, bbox.maxY for spanset in badMaskSpanSet: bbox = spanset.getBBox() # Dilate the bbox to make sure we have enough good pixels around the defect # For now, we dilate by 5 times the correlation length - # For GP with isotropic kernel, points at 5 correlation lengths away have negligible - # effect on the prediction. + # For GP with the isotropic kernel, points at the default value of + # correlation_length_cut=5 have negligible effect on the prediction. bbox = bbox.dilatedBy( int(self.correlation_length * self.correlation_length_cut) ) # need integer as input. - xmin, xmax = max([glob_xmin, bbox.minX]), min(glob_xmax, bbox.maxX) - ymin, ymax = max([glob_ymin, bbox.minY]), min(glob_ymax, bbox.maxY) + xmin, xmax = max([global_xmin, bbox.minX]), min(global_xmax, bbox.maxX) + ymin, ymax = max([global_ymin, bbox.minY]), min(global_ymax, bbox.maxY) localBox = Box2I(Point2I(xmin, ymin), Extent2I(xmax - xmin, ymax - ymin)) try: sub_masked_image = self.maskedImage[localBox] @@ -391,22 +343,22 @@ def interpolate_over_defects(self): sub_masked_image = self.interpolate_sub_masked_image(sub_masked_image) self.maskedImage[localBox] = sub_masked_image - def _good_pixel_binning(self, good_pixel): + def _good_pixel_binning(self, pixels): """ - Performs good pixel binning. + Performs pixel binning using treegp.meanify Parameters: ----------- - good_pixel : `np.array` - The array of good pixels. + pixels : `np.array` + The array of pixels. Returns: -------- `np.array` - The binned array of good pixels. + The binned array of pixels. """ - n_pixels = len(good_pixel[:, 0]) + n_pixels = len(pixels[:, 0]) dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning)) if n_pixels / self.bin_spacing**2 < n_pixels / dynamic_binning**2: bin_spacing = self.bin_spacing @@ -414,8 +366,8 @@ def _good_pixel_binning(self, good_pixel): bin_spacing = dynamic_binning binning = treegp.meanify(bin_spacing=bin_spacing, statistics="mean") binning.add_field( - good_pixel[:, :2], - good_pixel[:, 2:].T, + pixels[:, :2], + pixels[:, 2:].T, ) binning.meanify() return np.array( @@ -454,7 +406,7 @@ def interpolate_sub_masked_image(self, sub_masked_image): white_noise = np.sqrt( np.mean(sub_image_array[np.isfinite(sub_image_array)]) ) - kernel_amplitude = np.std(good_pixel[:, 2:]) + kernel_amplitude = np.max(good_pixel[:, 2:]) if not np.isfinite(kernel_amplitude): filter_finite = np.isfinite(good_pixel[:, 2:]).T[0] good_pixel = good_pixel[filter_finite] @@ -466,23 +418,21 @@ def interpolate_sub_masked_image(self, sub_masked_image): # kernel amplitude might be better described by maximum value of good pixel given # the data and not really a random gaussian field. kernel_amplitude = np.max(good_pixel[:, 2:]) - else: - kernel_amplitude = np.max(good_pixel[:, 2:]) try: good_pixel = self._good_pixel_binning(copy.deepcopy(good_pixel)) except Exception: warnings.warn( - "Binning failed, use original good pixel array in interpolate over." + "Binning failed, use original good pixel array in interpolation." ) - # put this after binning as comupting median is O(n*log(n)) - mean = median_with_mad_clipping(good_pixel[:, 2:]) + # put this after binning as computing median is O(n*log(n)) + clipped_median = median_with_mad_clipping(good_pixel[:, 2:]) gp = self.GaussianProcess( std=np.sqrt(kernel_amplitude), correlation_length=self.correlation_length, white_noise=white_noise, - mean=mean, + mean=clipped_median, ) gp.fit(good_pixel[:, :2], np.squeeze(good_pixel[:, 2:])) if bad_pixel.size < self.threshold_subdivide: @@ -497,7 +447,7 @@ def interpolate_sub_masked_image(self, sub_masked_image): np.shape(bad_pixel[i:end, 2:]) ) - # update_value + # Update values ctUtils.updateImageFromArray(sub_masked_image.image, bad_pixel) updateMaskFromArray(sub_masked_image.mask, bad_pixel, self.interpBit) return sub_masked_image @@ -515,6 +465,10 @@ def interpolateOverDefectsGP( """ Interpolate over defects in an image using Gaussian Process interpolation. + This function performs Gaussian Process interpolation over defects in the input image. + It uses the provided defect coordinates to identify and interpolate over the defects. + The interpolated image is not returned, instead, the input image is modified in-place. + Parameters ---------- image : `np.array` @@ -541,12 +495,6 @@ def interpolateOverDefectsGP( UserWarning If no defects are found in the image. - Notes - ----- - This function performs Gaussian Process interpolation over defects in the input image. - It uses the provided defect coordinates to identify and interpolate over the defects. - The interpolated image is not returned, instead, the input image is modified in-place. - """ if badList == [] or badList is None: warnings.warn("WARNING: no defects found. No interpolation performed.")