diff --git a/python/lsst/meas/algorithms/gp_interpolation.py b/python/lsst/meas/algorithms/gp_interpolation.py index 7e67fae62..2e56d930c 100644 --- a/python/lsst/meas/algorithms/gp_interpolation.py +++ b/python/lsst/meas/algorithms/gp_interpolation.py @@ -21,7 +21,7 @@ import numpy as np from lsst.meas.algorithms import CloughTocher2DInterpolatorUtils as ctUtils -from lsst.geom import Box2I, Point2I, Extent2I +from lsst.geom import Box2I, Point2I from lsst.afw.geom import SpanSet import copy import treegp @@ -32,7 +32,7 @@ import logging -__all__ = ["InterpolateOverDefectGaussianProcess"] +__all__ = ["InterpolateOverDefectGaussianProcess", "GaussianProcessJax", "GaussianProcessTreegp"] def updateMaskFromArray(mask, bad_pixel, interpBit): @@ -329,6 +329,7 @@ def __init__( defects=["SAT"], method="treegp", fwhm=5, + bin_image=True, bin_spacing=10, threshold_dynamic_binning=1000, threshold_subdivide=20000, @@ -344,6 +345,7 @@ def __init__( self.log = log or logging.getLogger(__name__) + self.bin_image = bin_image self.bin_spacing = bin_spacing self.threshold_subdivide = threshold_subdivide self.threshold_dynamic_binning = threshold_dynamic_binning @@ -383,7 +385,7 @@ def run(self): ) # need integer as input. xmin, xmax = max([global_xmin, bbox.minX]), min(global_xmax, bbox.maxX) ymin, ymax = max([global_ymin, bbox.minY]), min(global_ymax, bbox.maxY) - localBox = Box2I(Point2I(xmin, ymin), Extent2I(xmax - xmin, ymax - ymin)) + localBox = Box2I(Point2I(xmin, ymin), Point2I(xmax - xmin, ymax - ymin)) masked_sub_image = self.masked_image[localBox] masked_sub_image = self.interpolate_masked_sub_image(masked_sub_image) @@ -464,12 +466,14 @@ def interpolate_masked_sub_image(self, masked_sub_image): # kernel amplitude might be better described by maximum value of good pixel given # the data and not really a random gaussian field. kernel_amplitude = np.max(good_pixel[:, 2:]) - try: - good_pixel = self._good_pixel_binning(copy.deepcopy(good_pixel)) - except Exception: - self.log.info( - "Binning failed, use original good pixel array in interpolation." - ) + + if self.bin_image: + try: + good_pixel = self._good_pixel_binning(copy.deepcopy(good_pixel)) + except Exception: + self.log.info( + "Binning failed, use original good pixel array in interpolation." + ) # put this after binning as computing median is O(n*log(n)) clipped_median = median_with_mad_clipping(good_pixel[:, 2:]) diff --git a/tests/test_gp_interp.py b/tests/test_gp_interp.py index 23c3819f3..c87f61e77 100644 --- a/tests/test_gp_interp.py +++ b/tests/test_gp_interp.py @@ -27,7 +27,31 @@ import lsst.utils.tests import lsst.geom import lsst.afw.image as afwImage -from lsst.meas.algorithms import InterpolateOverDefectGaussianProcess +from lsst.meas.algorithms import InterpolateOverDefectGaussianProcess, GaussianProcessTreegp + +def rbf_kernel(x1, x2, sigma, correlation_length): + """ + Computes the radial basis function (RBF) kernel matrix. + + Parameters: + ----------- + x1 : `np.array` + Location of training data point with shape (n_samples, n_features). + x2 : `np.array` + Location of training/test data point with shape (n_samples, n_features). + sigma : `float` + The scale parameter of the kernel. + correlation_length : `float` + The correlation length parameter of the kernel. + + Returns: + -------- + kernel : `np.array` + RBF kernel matrix with shape (n_samples, n_samples). + """ + distance_squared = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1) + kernel = (sigma**2) * np.exp(-0.5 * distance_squared / (correlation_length**2)) + return kernel class InterpolateOverDefectGaussianProcessTestCase(lsst.utils.tests.TestCase): @@ -36,10 +60,40 @@ class InterpolateOverDefectGaussianProcessTestCase(lsst.utils.tests.TestCase): def setUp(self): super().setUp() + npoints = 1000 + self.std = 100 + self.correlation_length = 10. + self.white_noise = 1e-5 + + x1 = np.random.uniform(0, 99, npoints) + x2 = np.random.uniform(0, 120, npoints) + coord1 = np.array([x1, x2]).T + + kernel = rbf_kernel(coord1, coord1, self.std, self.correlation_length) + kernel += np.eye(npoints) * self.white_noise**2 + + # Data augmentation. Create a gaussian random field + # on a 100 * 100 is to slow. So generate 1e3 points + # and then interpolate it with a GP to do data augmentation. + + np.random.seed(42) + z1 = np.random.multivariate_normal(np.zeros(npoints), kernel) + + x1 = np.linspace(0, 99, 100) + x2 = np.linspace(0, 120, 121) + x2, x1 = np.meshgrid(x2, x1) + coord2 = np.array([x1.reshape(-1), x2.reshape(-1)]).T + + tgp = GaussianProcessTreegp(std=self.std, correlation_length=self.correlation_length, + white_noise=self.white_noise, mean=0.0) + tgp.fit(coord1, z1) + z2 = tgp.predict(coord2) + z2 = z2.reshape(100, 121) + self.maskedimage = afwImage.MaskedImageF(100, 121) for x in range(100): for y in range(121): - self.maskedimage[x, y] = (3 * y + x * 5, 0, 1.0) + self.maskedimage[x, y] = (z2[x, y], 0, 1.0) # Clone the maskedimage so we can compare it after running the task. self.reference = self.maskedimage.clone() @@ -76,11 +130,11 @@ def setUp(self): self.reference.mask.array[:, :] = self.maskedimage.mask.array # Create a noise image - self.noise = self.maskedimage.clone() - np.random.seed(12345) - self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape) + # self.noise = self.maskedimage.clone() + # np.random.seed(12345) + # self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape) - @lsst.utils.tests.methodParameters(method=("treegp", "jax")) + @lsst.utils.tests.methodParameters(method=("jax")) def test_interpolation(self, method: str): """Test that the interpolation is done correctly. @@ -93,8 +147,9 @@ def test_interpolation(self, method: str): gp = InterpolateOverDefectGaussianProcess(self.maskedimage, defects=["BAD", "SAT", "CR", "EDGE"], method=method, - fwhm=5, - bin_spacing=10, + fwhm=self.correlation_length, + bin_image=False, + bin_spacing=30, threshold_dynamic_binning=1000, threshold_subdivide=20000, correlation_length_cut=5, @@ -111,15 +166,8 @@ def test_interpolation(self, method: str): self.assertImagesAlmostEqual( self.maskedimage.image[1:, :], self.reference.image[1:, :], - rtol=1e-05, - atol=1e-08, + atol=2, ) - - def test_dumb(self): - TEST = 42 - if TEST == 42: - raise ValueError('TEST == 42') - def setup_module(module):