diff --git a/python/lsst/meas/algorithms/gp_interpolation.py b/python/lsst/meas/algorithms/gp_interpolation.py index 33a5e7554..2a45cabc0 100644 --- a/python/lsst/meas/algorithms/gp_interpolation.py +++ b/python/lsst/meas/algorithms/gp_interpolation.py @@ -39,6 +39,7 @@ def updateMaskFromArray(mask, bad_pixel, interpBit): y = int(row[1] - y0) mask.array[y, x] |= interpBit + @jit def median_with_mad_clipping(data, mad_multiplier=2.0): """ @@ -76,6 +77,7 @@ def median_with_mad_clipping(data, mad_multiplier=2.0): median_clipped = jnp.median(clipped_data) return median_clipped + # Below are the jax functions for Gaussian Process regression. # Kernel is assumed to be isotropic RBF kernel, and solved # using exact Cholesky decomposition. @@ -88,6 +90,7 @@ def median_with_mad_clipping(data, mad_multiplier=2.0): # of custom things (setting my own hyperparameters, fine tune mean function, # dynamic binning, ...). + @jit def jax_pdist_squared(x): """ @@ -118,6 +121,7 @@ def jax_pdist_squared(x): """ return jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1) + @jit def jax_cdist_squared(xa, xb): """ @@ -152,6 +156,7 @@ def jax_cdist_squared(xa, xb): """ return jnp.sum((xa[:, None, :] - xb[None, :, :]) ** 2, axis=-1) + @jit def jax_rbf_kernel(x, sigma, correlation_length, y_err): """ @@ -175,10 +180,11 @@ def jax_rbf_kernel(x, sigma, correlation_length, y_err): """ distance_squared = jax_pdist_squared(x) kernel = (sigma**2) * jnp.exp(-0.5 * distance_squared / (correlation_length**2)) - y_err = jnp.ones(len(x[:,0])) * y_err + y_err = jnp.ones(len(x[:, 0])) * y_err kernel += jnp.eye(len(y_err)) * (y_err**2) return kernel + @jit def jax_rbf_kernel_rect(x1, x2, sigma, correlation_length): """ @@ -227,6 +233,7 @@ def jax_get_alpha(y, kernel): alpha = jax.scipy.linalg.cho_solve(factor, y, overwrite_b=False) return alpha.reshape((len(alpha), 1)) + @jit def jax_get_y_predict(kernel_rect, alpha): """ @@ -250,27 +257,24 @@ def jax_get_y_predict(kernel_rect, alpha): class GaussianProcessJax: def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): - self.std = std - self.l = correlation_length + self.correlation_lenght = correlation_length self.white_noise = white_noise self.mean = mean self._alpha = None def fit(self, x_good, y_good): - y = y_good - self.mean self._x = x_good - kernel = jax_rbf_kernel(x_good, self.std, self.l, self.white_noise) + kernel = jax_rbf_kernel(x_good, self.std, self.correlation_lenght, self.white_noise) self._alpha = jax_get_alpha(y, kernel) - def predict(self, x_bad): - - kernel_rect = jax_rbf_kernel_rect(x_bad, self._x, self.std, self.l) + kernel_rect = jax_rbf_kernel_rect(x_bad, self._x, self.std, self.correlation_lenght) y_pred = jax_get_y_predict(kernel_rect, self._alpha) return y_pred + self.mean + # Vanilla Gaussian Process regression using treegp package # There is no fancy O(N*log(N)) solver here, just the basic GP regression (Cholesky). class GaussianProcessTreegp: @@ -292,7 +296,7 @@ class GaussianProcessTreegp: -------- fit(x_good, y_good): Fit the Gaussian Process to the given training data. - + Parameters: ----------- x_good : array-like @@ -302,7 +306,7 @@ class GaussianProcessTreegp: predict(x_bad): Predict the target values for the given input features. - + Parameters: ----------- x_bad : array-like @@ -317,7 +321,7 @@ class GaussianProcessTreegp: def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): self.std = std - self.l = correlation_length + self.correlation_length = correlation_length self.white_noise = white_noise self.mean = mean @@ -332,7 +336,7 @@ def fit(self, x_good, y_good): y_good : array-like Target values for the training data. """ - kernel = f"{self.std}**2 * RBF({self.l})" + kernel = f"{self.std}**2 * RBF({self.correlation_length})" self.gp = treegp.GPInterpolation( kernel=kernel, optimizer="none", @@ -362,7 +366,8 @@ def predict(self, x_bad): class InterpolateOverDefectGaussianProcess: """ - InterpolateOverDefectGaussianProcess class performs Gaussian Process (GP) interpolation over defects in an image. + InterpolateOverDefectGaussianProcess class performs Gaussian Process + (GP) interpolation over defects in an image. Parameters: ----------- @@ -429,7 +434,6 @@ def __init__( threshold_subdivide=20000, correlation_length_cut=5, ): - if method == "jax": self.GaussianProcess = GaussianProcessJax elif method == "treegp": @@ -467,18 +471,18 @@ def interpolate_over_defects(self): # For now, we dilate by 5 times the correlation length # For GP with isotropic kernel, points at 5 correlation lengths away have negligible # effect on the prediction. - bbox = bbox.dilatedBy(int(self.correlation_length * self.correlation_length_cut)) # need integer as input. + bbox = bbox.dilatedBy( + int(self.correlation_length * self.correlation_length_cut) + ) # need integer as input. xmin, xmax = max([glob_xmin, bbox.minX]), min(glob_xmax, bbox.maxX) ymin, ymax = max([glob_ymin, bbox.minY]), min(glob_ymax, bbox.maxY) localBox = Box2I(Point2I(xmin, ymin), Extent2I(xmax - xmin, ymax - ymin)) try: sub_masked_image = self.maskedImage[localBox] - except: + except IndexError: raise ValueError("Sub-masked image not found.") - sub_masked_image = self.interpolate_sub_masked_image( - sub_masked_image - ) + sub_masked_image = self.interpolate_sub_masked_image(sub_masked_image) self.maskedImage[localBox] = sub_masked_image def _good_pixel_binning(self, good_pixel): @@ -496,17 +500,21 @@ def _good_pixel_binning(self, good_pixel): The binned array of good pixels. """ - n_pixels = len(good_pixel[:,0]) + n_pixels = len(good_pixel[:, 0]) dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning)) - if n_pixels/self.bin_spacing**2 < n_pixels/dynamic_binning**2: + if n_pixels / self.bin_spacing**2 < n_pixels / dynamic_binning**2: bin_spacing = self.bin_spacing else: bin_spacing = dynamic_binning - binning = treegp.meanify(bin_spacing=bin_spacing, statistics='mean') - binning.add_field(good_pixel[:, :2], good_pixel[:, 2:].T,) + binning = treegp.meanify(bin_spacing=bin_spacing, statistics="mean") + binning.add_field( + good_pixel[:, :2], + good_pixel[:, 2:].T, + ) binning.meanify() - return np.array([binning.coords0[:, 0], binning.coords0[:, 1], binning.params0]).T - + return np.array( + [binning.coords0[:, 0], binning.coords0[:, 1], binning.params0] + ).T def interpolate_sub_masked_image(self, sub_masked_image): """ @@ -523,13 +531,15 @@ def interpolate_sub_masked_image(self, sub_masked_image): The interpolated sub-masked image. """ - cut = int(self.correlation_length * self.correlation_length_cut) # need integer as input. + cut = int( + self.correlation_length * self.correlation_length_cut + ) # need integer as input. bad_pixel, good_pixel = ctUtils.findGoodPixelsAroundBadPixels( sub_masked_image, self.defects, buffer=cut ) # Do nothing if bad pixel is None. if bad_pixel.size == 0 or good_pixel.size == 0: - warnings.warn('No bad or good pixels found. No interpolation performed.') + warnings.warn("No bad or good pixels found. No interpolation performed.") return sub_masked_image # Do GP interpolation if bad pixel found. else: @@ -543,7 +553,9 @@ def interpolate_sub_masked_image(self, sub_masked_image): filter_finite = np.isfinite(good_pixel[:, 2:]).T[0] good_pixel = good_pixel[filter_finite] if good_pixel.size == 0: - warnings.warn('No bad or good pixels found. No interpolation performed.') + warnings.warn( + "No bad or good pixels found. No interpolation performed." + ) return sub_masked_image # kernel amplitude might be better described by maximum value of good pixel given # the data and not really a random gaussian field. @@ -552,8 +564,10 @@ def interpolate_sub_masked_image(self, sub_masked_image): kernel_amplitude = np.max(good_pixel[:, 2:]) try: good_pixel = self._good_pixel_binning(copy.deepcopy(good_pixel)) - except: - warnings.warn('Binning failed, use original good pixel array in interpolate over.') + except Exception: + warnings.warn( + "Binning failed, use original good pixel array in interpolate over." + ) # put this after binning as comupting median is O(n*log(n)) mean = median_with_mad_clipping(good_pixel[:, 2:]) @@ -569,19 +583,29 @@ def interpolate_sub_masked_image(self, sub_masked_image): gp_predict = gp.predict(bad_pixel[:, :2]) bad_pixel[:, 2:] = gp_predict.reshape(np.shape(bad_pixel[:, 2:])) else: - warnings.warn('sub-divide bad pixel array to avoid memory error.') + warnings.warn("sub-divide bad pixel array to avoid memory error.") for i in range(0, len(bad_pixel), self.threshold_subdivide): - end = min(i + self.threshold_subdivide, len(bad_pixel)) - gp_predict = gp.predict(bad_pixel[i : end, :2]) - bad_pixel[i : end, 2:] = gp_predict.reshape(np.shape(bad_pixel[i : end, 2:])) + end = min(i + self.threshold_subdivide, len(bad_pixel)) + gp_predict = gp.predict(bad_pixel[i:end, :2]) + bad_pixel[i:end, 2:] = gp_predict.reshape( + np.shape(bad_pixel[i:end, 2:]) + ) # update_value ctUtils.updateImageFromArray(sub_masked_image.image, bad_pixel) updateMaskFromArray(sub_masked_image.mask, bad_pixel, self.interpBit) return sub_masked_image - -def interpolateOverDefectsGP(image, fwhm, badList, method="treegp", bin_spacing=25, - threshold_dynamic_binning=1000, threshold_subdivide=20000): + + +def interpolateOverDefectsGP( + image, + fwhm, + badList, + method="treegp", + bin_spacing=25, + threshold_dynamic_binning=1000, + threshold_subdivide=20000, +): """ Interpolates over defects in an image using Gaussian Process interpolation. @@ -619,10 +643,15 @@ def interpolateOverDefectsGP(image, fwhm, badList, method="treegp", bin_spacing= """ if badList == [] or badList is None: - warnings.warn('WARNING: no defects found. No interpolation performed.') + warnings.warn("WARNING: no defects found. No interpolation performed.") return - gp = InterpolateOverDefectGaussianProcess(image, defects=badList, method=method, - fwhm=fwhm, bin_spacing=bin_spacing, - threshold_dynamic_binning=threshold_dynamic_binning, - threshold_subdivide=threshold_subdivide) + gp = InterpolateOverDefectGaussianProcess( + image, + defects=badList, + method=method, + fwhm=fwhm, + bin_spacing=bin_spacing, + threshold_dynamic_binning=threshold_dynamic_binning, + threshold_subdivide=threshold_subdivide, + ) gp.interpolate_over_defects() diff --git a/python/lsst/meas/algorithms/interp.py b/python/lsst/meas/algorithms/interp.py index 7ba692aa3..5fcdbcf96 100644 --- a/python/lsst/meas/algorithms/interp.py +++ b/python/lsst/meas/algorithms/interp.py @@ -1,13 +1,23 @@ from . import interpolateOverDefectsOld from . import interpolateOverDefectsGP -__all__ = ['interpolateOverDefects'] +__all__ = ["interpolateOverDefects"] -def interpolateOverDefects(image, psf, badList, fallbackValue=0.0, fwhm=1.0, - useFallbackValueAtEdge=False, useLegacyInterp=False, - maskNameList=None, **kwargs): +def interpolateOverDefects( + image, + psf, + badList, + fallbackValue=0.0, + fwhm=1.0, + useFallbackValueAtEdge=False, + useLegacyInterp=False, + maskNameList=None, + **kwargs +): if useLegacyInterp: - return interpolateOverDefectsOld(image, psf, badList, fallbackValue, useFallbackValueAtEdge) + return interpolateOverDefectsOld( + image, psf, badList, fallbackValue, useFallbackValueAtEdge + ) else: - return interpolateOverDefectsGP(image, fwhm, maskNameList, **kwargs) \ No newline at end of file + return interpolateOverDefectsGP(image, fwhm, maskNameList, **kwargs)