diff --git a/python/lsst/meas/algorithms/gp_interpolation.py b/python/lsst/meas/algorithms/gp_interpolation.py index 1a08fbac5..7e67fae62 100644 --- a/python/lsst/meas/algorithms/gp_interpolation.py +++ b/python/lsst/meas/algorithms/gp_interpolation.py @@ -198,16 +198,26 @@ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): self.mean = mean self._alpha = None + # Looks weird to do that, but this is justified. + # in GP if no noise is provided, even if matrix + # can be inverted, it wont invert because of numerical + # issue (det(K)~0). Add a little bit of noise allow + # to compute a numerical solution in the case of no + # external noise is added. Wont happened on real + # image but help for unit test. + if self.white_noise == 0.0: + self.white_noise = 1e-5 + def fit(self, x_train, y_train): y = y_train - self.mean self._x = x_train - kernel = jax_rbf_kernel(x_train, self.std, self.correlation_length) + kernel = jax_rbf_kernel(x_train, x_train, self.std, self.correlation_length) y_err = jnp.ones(len(x_train[:, 0])) * self.white_noise kernel += jnp.eye(len(y_err)) * (y_err**2) self._alpha = jax_get_alpha(y, kernel) def predict(self, x_predict): - kernel_rect = jax_rbf_kernel(x_predict, self._x, self.std, self.correlation_length, 0) + kernel_rect = jax_rbf_kernel(x_predict, self._x, self.std, self.correlation_length) y_pred = jax_get_gp_predict(kernel_rect, self._alpha) return y_pred + self.mean @@ -236,6 +246,16 @@ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): self.white_noise = white_noise self.mean = mean + # Looks like weird to do that, but this is justified. + # in GP if no noise is provided, even if matrix + # can be inverted, it wont invert because of numerical + # issue (det(K)~0). Add a little bit of noise allow + # to compute a numerical solution in the case of no + # external noise is added. Wont happened on real + # image but help for unit test. + if self.white_noise == 0.0: + self.white_noise = 1e-5 + def fit(self, x_train, y_train): """ Fit the Gaussian Process to the given training data. diff --git a/tests/test_gp_interp.py b/tests/test_gp_interp.py index 9ebb31957..23c3819f3 100644 --- a/tests/test_gp_interp.py +++ b/tests/test_gp_interp.py @@ -22,17 +22,12 @@ import unittest -from typing import Iterable -from itertools import product import numpy as np import lsst.utils.tests import lsst.geom import lsst.afw.image as afwImage -from lsst.meas.algorithms.cloughTocher2DInterpolator import ( - CloughTocher2DInterpolateTask, -) -from lsst.meas.algorithms import CloughTocher2DInterpolatorUtils as ctUtils +from lsst.meas.algorithms import InterpolateOverDefectGaussianProcess class InterpolateOverDefectGaussianProcessTestCase(lsst.utils.tests.TestCase): @@ -85,84 +80,30 @@ def setUp(self): np.random.seed(12345) self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape) - @lsst.utils.tests.methodParameters(n_runs=(1, 2)) - def test_interpolation(self, n_runs: int): + @lsst.utils.tests.methodParameters(method=("treegp", "jax")) + def test_interpolation(self, method: str): """Test that the interpolation is done correctly. Parameters ---------- - n_runs : `int` - Number of times to run the task. Running the task more than once - should have no effect. + method : `str` + Code used to solve gaussian process. """ - config = CloughTocher2DInterpolateTask.ConfigClass() - config.badMaskPlanes = ( - "BAD", - "SAT", - "CR", - "EDGE", - ) - config.fillValue = 0.5 - task = CloughTocher2DInterpolateTask(config) - for n in range(n_runs): - task.run(self.maskedimage) - # Assert that the mask and the variance planes remain unchanged. - self.assertImagesEqual(self.maskedimage.variance, self.reference.variance) - self.assertMasksEqual(self.maskedimage.mask, self.reference.mask) + gp = InterpolateOverDefectGaussianProcess(self.maskedimage, + defects=["BAD", "SAT", "CR", "EDGE"], + method=method, + fwhm=5, + bin_spacing=10, + threshold_dynamic_binning=1000, + threshold_subdivide=20000, + correlation_length_cut=5, + log=None,) - # Check that the long streak of bad pixels have been replaced with the - # fillValue, but not the short streak. - np.testing.assert_array_equal(self.maskedimage.image[0:1, :].array, config.fillValue) - with self.assertRaises(AssertionError): - np.testing.assert_array_equal(self.maskedimage.image[25:28, 0:1].array, config.fillValue) + gp.run() - # Check that interpolated pixels are close to the reference (original), - # and that none of them is still NaN. - self.assertTrue(np.isfinite(self.maskedimage.image.array).all()) - self.assertImagesAlmostEqual( - self.maskedimage.image[1:, :], - self.reference.image[1:, :], - rtol=1e-05, - atol=1e-08, - ) - TEST = 42 - if TEST == 42: - raise ValueError('TEST == 42') - - @lsst.utils.tests.methodParametersProduct(pass_badpix=(True, False), pass_goodpix=(True, False)) - def test_interpolation_with_noise(self, pass_badpix: bool = True, pass_goodpix: bool = True): - """Test that we can reuse the badpix and goodpix. - - Parameters - ---------- - pass_badpix : `bool` - Whether to pass the badpix to the task? - pass_goodpix : `bool` - Whether to pass the goodpix to the task? - """ - - config = CloughTocher2DInterpolateTask.ConfigClass() - config.badMaskPlanes = ( - "BAD", - "SAT", - "CR", - "EDGE", - ) - task = CloughTocher2DInterpolateTask(config) - - badpix, goodpix = task.run(self.noise) - task.run( - self.maskedimage, - badpix=(badpix if pass_badpix else None), - goodpix=(goodpix if pass_goodpix else None), - ) - - # Check that the long streak of bad pixels by the edge have been - # replaced with fillValue, but not the short streak. - np.testing.assert_array_equal(self.maskedimage.image[0:1, :].array, config.fillValue) - with self.assertRaises(AssertionError): - np.testing.assert_array_equal(self.maskedimage.image[25:28, 0:1].array, config.fillValue) + # Assert that the mask and the variance planes remain unchanged. + self.assertImagesEqual(self.maskedimage.variance, self.reference.variance) # Check that interpolated pixels are close to the reference (original), # and that none of them is still NaN. @@ -173,9 +114,12 @@ def test_interpolation_with_noise(self, pass_badpix: bool = True, pass_goodpix: rtol=1e-05, atol=1e-08, ) + + def test_dumb(self): TEST = 42 if TEST == 42: raise ValueError('TEST == 42') + def setup_module(module):