diff --git a/python/lsst/meas/algorithms/gp_interpolation.py b/python/lsst/meas/algorithms/gp_interpolation.py index b95397c6d..672588b54 100644 --- a/python/lsst/meas/algorithms/gp_interpolation.py +++ b/python/lsst/meas/algorithms/gp_interpolation.py @@ -14,19 +14,6 @@ __all__ = ["interpolateOverDefectsGP"] -# TO DO: REMOVE this helper function. -def get_name_pkl(): - """ - Get name. - """ - import glob - import os - rep = '/sdf/home/l/leget/rubin-user/lsst_dev/tickets/DM-44305/data/' - all_pkl = glob.glob(os.path.join(rep, 'out_test_*')) - n_file = len(all_pkl) - file_name = f"out_test_{n_file}.pkl" - return os.path.join(rep, file_name) - @jit def median_with_mad_clipping(data, mad_multiplier=2.0): @@ -57,14 +44,6 @@ def jax_get_y_predict(HT, alpha): @jit def jax_rbf_k(x1, sigma, correlation_length, y_err): - """Compute the RBF kernel with JAX. - - :param x1: The first set of points. (n_samples,) - :param sigma: The amplitude of the kernel. - :param correlation_length: The correlation length of the kernel. - :param y_err: The error of the field. (n_samples) - :param white_noise: The white noise of the field. - """ l1 = jax_pdist_squared(x1) K = (sigma**2) * jnp.exp(-0.5 * l1 / (correlation_length**2)) @@ -74,32 +53,15 @@ def jax_rbf_k(x1, sigma, correlation_length, y_err): @jit def jax_rbf_h(x1, x2, sigma, correlation_length): - """Compute the RBF kernel with JAX. - - :param x1: The first set of points. (n_samples,) - :param x2: The second set of points. (n_samples) - :param sigma: The amplitude of the kernel. - :param correlation_length: The correlation length of the kernel. - :param y_err: The error of the field. (n_samples) - :param white_noise: The white noise of the field. - """ + l1 = jax_cdist_squared(x1, x2) K = (sigma**2) * jnp.exp(-0.5 * l1 / (correlation_length**2)) return K -def updateMaskFromArray(mask, bad_pixel, interpBit): - x0 = mask.getX0() - y0 = mask.getY0() - for row in bad_pixel: - x = int(row[0] - x0) - y = int(row[1] - y0) - mask.array[y, x] |= interpBit class GaussianProcessJax: def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): - """ - TO DO - """ + self.std = std self.l = correlation_length self.white_noise = white_noise @@ -107,14 +69,7 @@ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): self._alpha = None def fit(self, x_good, y_good): - """ - Fits the Gaussian Process regression model to the given training data. - Args: - x_good (array-like): The input features of the training data. - y_good (array-like): The target values of the training data. - - """ y = y_good - self.mean self._x = x_good K = jax_rbf_k(x_good, self.std, self.l, self.white_noise) @@ -122,72 +77,35 @@ def fit(self, x_good, y_good): def predict(self, x_bad): - """ - Makes predictions using the fitted Gaussian Process regression model. - - Args: - x (array-like): The input features for which to make predictions. - - Returns: - array-like: The predicted target values. - """ HT = jax_rbf_h(x_bad, self._x, self.std, self.l) y_pred = jax_get_y_predict(HT, self._alpha) return y_pred + self.mean - +def updateMaskFromArray(mask, bad_pixel, interpBit): + x0 = mask.getX0() + y0 = mask.getY0() + for row in bad_pixel: + x = int(row[0] - x0) + y = int(row[1] - y0) + mask.array[y, x] |= interpBit # Vanilla Gaussian Process regression using treegp package # There is no fancy O(N*log(N)) solver here, just the basic GP regression (Cholesky). class GaussianProcessTreegp: - """ - Gaussian Process regression using treegp package. - - This class implements Gaussian Process regression using the treegp package. It provides methods for fitting the - regression model and making predictions. - - Attributes: - std (float): The standard deviation parameter for the Gaussian Process kernel. - l (float): The correlation length parameter for the Gaussian Process kernel. - white_noise (float): The white noise parameter for the Gaussian Process kernel. - mean (float): The mean parameter for the Gaussian Process kernel. - - Methods: - fit(x_good, y_good): Fits the Gaussian Process regression model to the given training data. - predict(x_bad): Makes predictions using the fitted Gaussian Process regression model. - - """ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): - """ - Initializes a new instance of the gp_treegp class. - Args: - std (float, optional): The standard deviation parameter for the Gaussian Process kernel. Defaults to 2. - correlation_length (float, optional): The correlation length parameter for the Gaussian Process kernel. - Defaults to 1. - white_noise (float, optional): The white noise parameter for the Gaussian Process kernel. Defaults to 0. - mean (float, optional): The mean parameter for the Gaussian Process kernel. Defaults to 0. - - """ self.std = std self.l = correlation_length self.white_noise = white_noise self.mean = mean def fit(self, x_good, y_good): - """ - Fits the Gaussian Process regression model to the given training data. - - Args: - x_good (array-like): The input features of the training data. - y_good (array-like): The target values of the training data. - """ - KERNEL = "%.2f**2 * RBF(%f)" % ((self.std, self.l)) + kernel = f"{self.std}**2 * RBF({self.l})" self.gp = treegp.GPInterpolation( - kernel=KERNEL, + kernel=kernel, optimizer="none", normalize=False, white_noise=self.white_noise, @@ -196,30 +114,12 @@ def fit(self, x_good, y_good): self.gp.solve() def predict(self, x_bad): - """ - Makes predictions using the fitted Gaussian Process regression model. - - Args: - x (array-like): The input features for which to make predictions. - Returns: - array-like: The predicted target values. - - """ y_pred = self.gp.predict(x_bad) return y_pred + self.mean class InterpolateOverDefectGaussianProcess: - """ - Class for interpolating over defects in a masked image using Gaussian Processes. - - Args: - maskedImage (MaskedImage): The masked image containing defects. - defects (list, optional): List of defect names to interpolate over. Defaults to ["SAT"]. - fwhm (float, optional): FWHM from PSF and used as prior for correlation length. Defaults to 5. - bin_spacing (float, optional): Spacing for binning. Defaults to 10. - """ def __init__( self, @@ -230,24 +130,15 @@ def __init__( bin_spacing=10, threshold_dynamic_binning=1000, threshold_subdivide=20000, + correlation_length_cut=5, ): - """ - Initializes the InterpolateOverDefectGaussianProcess class. - - Args: - maskedImage (MaskedImage): The masked image containing defects. - defects (list, optional): List of defect names to interpolate over. Defaults to ["SAT"]. - fwhm (float, optional): FWHM from PSF and used as prior for correlation length. Defaults to 5. - bin_spacing (float, optional): Spacing for binning. Defaults to 10. - """ - - if method not in ["jax", "treegp"]: - raise ValueError("Invalid method. Must be 'jax' or 'treegp'.") if method == "jax": self.GaussianProcess = GaussianProcessJax - if method == "treegp": + elif method == "treegp": self.GaussianProcess = GaussianProcessTreegp + else: + raise ValueError("Invalid method. Must be 'jax' or 'treegp'.") self.bin_spacing = bin_spacing self.threshold_subdivide = threshold_subdivide @@ -256,16 +147,13 @@ def __init__( self.maskedImage = maskedImage self.defects = defects self.correlation_length = fwhm + self.correlation_length_cut = correlation_length_cut self.interpBit = self.maskedImage.mask.getPlaneBitMask("INTRP") def interpolate_over_defects(self): - """ - Interpolates over defects using the spanset method. - """ mask = self.maskedImage.getMask() - # breakpoint() badPixelMask = mask.getPlaneBitMask(self.defects) badMaskSpanSet = SpanSet.fromMask(mask, badPixelMask).split() @@ -273,14 +161,13 @@ def interpolate_over_defects(self): glob_xmin, glob_xmax = bbox.minX, bbox.maxX glob_ymin, glob_ymax = bbox.minY, bbox.maxY - for i in range(len(badMaskSpanSet)): - spanset = badMaskSpanSet[i] + for spanset in badMaskSpanSet: bbox = spanset.getBBox() # Dilate the bbox to make sure we have enough good pixels around the defect # For now, we dilate by 5 times the correlation length # For GP with isotropic kernel, points at 5 correlation lengths away have negligible # effect on the prediction. - bbox = bbox.dilatedBy(int(self.correlation_length * 5)) # need integer as input. + bbox = bbox.dilatedBy(int(self.correlation_length * self.correlation_length_cut)) # need integer as input. xmin, xmax = max([glob_xmin, bbox.minX]), min(glob_xmax, bbox.maxX) ymin, ymax = max([glob_ymin, bbox.minY]), min(glob_ymax, bbox.maxY) localBox = Box2I(Point2I(xmin, ymin), Extent2I(xmax - xmin, ymax - ymin)) @@ -295,16 +182,7 @@ def interpolate_over_defects(self): self.maskedImage[localBox] = sub_masked_image def _good_pixel_binning(self, good_pixel): - """ - Performs binning of good pixel data. - Parameters: - - good_pixel (numpy.ndarray): An array containing the good pixel data. - - Returns: - - numpy.ndarray: An array containing the binned data. - - """ n_pixels = len(good_pixel[:,0]) dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning)) if n_pixels/self.bin_spacing**2 < n_pixels/dynamic_binning**2: @@ -318,17 +196,8 @@ def _good_pixel_binning(self, good_pixel): def interpolate_sub_masked_image(self, sub_masked_image): - """ - Interpolates over defects in a sub-masked image. - - Args: - sub_masked_image (MaskedImage): The sub-masked image containing defects. - Returns: - MaskedImage: The sub-masked image with defects interpolated. - """ - - cut = int(self.correlation_length * 5) # need integer as input. + cut = int(self.correlation_length * self.correlation_length_cut) # need integer as input. bad_pixel, good_pixel = ctUtils.findGoodPixelsAroundBadPixels( sub_masked_image, self.defects, buffer=cut ) @@ -383,33 +252,7 @@ def interpolate_sub_masked_image(self, sub_masked_image): def interpolateOverDefectsGP(image, fwhm, badList, method="treegp", bin_spacing=25, threshold_dynamic_binning=1000, threshold_subdivide=20000): - """ - Interpolates over defects in an image using Gaussian Process interpolation. - - Args: - image : The input image. - fwhm (float): The full width at half maximum (FWHM) of the PSF used for approximation of correlation lenght. - badList (list): A list of defects to interpolate over. - bin_spacing (int, optional): The spacing between bins when resampling. Defaults to 15. - threshold_subdivide (int, optional): The threshold number of bad pixels to subdivide to avoid memory issue. Defaults to 20000. - - Returns: - None - """ - #import pickle - #dic = { - # 'image': image, - # 'fwhm': fwhm, - # 'badList': badList, - # 'method': method, - # 'bin_spacing': bin_spacing, - # 'threshold_subdivide':threshold_subdivide, - # } - #file_name_out = get_name_pkl() - #fileout = open(file_name_out, 'wb') - #print(f'JE PASSE PAR LA | {file_name_out}') - #pickle.dump(dic, fileout) - #fileout.close() + if badList == [] or badList is None: warnings.warn('WARNING: no defects found. No interpolation performed.') return