diff --git a/python/lsst/meas/algorithms/gp_interpolation.py b/python/lsst/meas/algorithms/gp_interpolation.py index 2e56d930c..6de217f9a 100644 --- a/python/lsst/meas/algorithms/gp_interpolation.py +++ b/python/lsst/meas/algorithms/gp_interpolation.py @@ -32,7 +32,11 @@ import logging -__all__ = ["InterpolateOverDefectGaussianProcess", "GaussianProcessJax", "GaussianProcessTreegp"] +__all__ = [ + "InterpolateOverDefectGaussianProcess", + "GaussianProcessJax", + "GaussianProcessTreegp", +] def updateMaskFromArray(mask, bad_pixel, interpBit): @@ -191,6 +195,7 @@ class GaussianProcessJax: Mean value of the Gaussian Process. Default is 0.0. """ + def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0): self.std = std self.correlation_length = correlation_length @@ -217,7 +222,9 @@ def fit(self, x_train, y_train): self._alpha = jax_get_alpha(y, kernel) def predict(self, x_predict): - kernel_rect = jax_rbf_kernel(x_predict, self._x, self.std, self.correlation_length) + kernel_rect = jax_rbf_kernel( + x_predict, self._x, self.std, self.correlation_length + ) y_pred = jax_get_gp_predict(kernel_rect, self._alpha) return y_pred + self.mean diff --git a/tests/test_gp_interp.py b/tests/test_gp_interp.py index c87f61e77..19f5c716b 100644 --- a/tests/test_gp_interp.py +++ b/tests/test_gp_interp.py @@ -27,7 +27,11 @@ import lsst.utils.tests import lsst.geom import lsst.afw.image as afwImage -from lsst.meas.algorithms import InterpolateOverDefectGaussianProcess, GaussianProcessTreegp +from lsst.meas.algorithms import ( + InterpolateOverDefectGaussianProcess, + GaussianProcessTreegp, +) + def rbf_kernel(x1, x2, sigma, correlation_length): """ @@ -62,7 +66,7 @@ def setUp(self): npoints = 1000 self.std = 100 - self.correlation_length = 10. + self.correlation_length = 10.0 self.white_noise = 1e-5 x1 = np.random.uniform(0, 99, npoints) @@ -71,7 +75,7 @@ def setUp(self): kernel = rbf_kernel(coord1, coord1, self.std, self.correlation_length) kernel += np.eye(npoints) * self.white_noise**2 - + # Data augmentation. Create a gaussian random field # on a 100 * 100 is to slow. So generate 1e3 points # and then interpolate it with a GP to do data augmentation. @@ -84,8 +88,12 @@ def setUp(self): x2, x1 = np.meshgrid(x2, x1) coord2 = np.array([x1.reshape(-1), x2.reshape(-1)]).T - tgp = GaussianProcessTreegp(std=self.std, correlation_length=self.correlation_length, - white_noise=self.white_noise, mean=0.0) + tgp = GaussianProcessTreegp( + std=self.std, + correlation_length=self.correlation_length, + white_noise=self.white_noise, + mean=0.0, + ) tgp.fit(coord1, z1) z2 = tgp.predict(coord2) z2 = z2.reshape(100, 121) @@ -144,16 +152,18 @@ def test_interpolation(self, method: str): Code used to solve gaussian process. """ - gp = InterpolateOverDefectGaussianProcess(self.maskedimage, - defects=["BAD", "SAT", "CR", "EDGE"], - method=method, - fwhm=self.correlation_length, - bin_image=False, - bin_spacing=30, - threshold_dynamic_binning=1000, - threshold_subdivide=20000, - correlation_length_cut=5, - log=None,) + gp = InterpolateOverDefectGaussianProcess( + self.maskedimage, + defects=["BAD", "SAT", "CR", "EDGE"], + method=method, + fwhm=self.correlation_length, + bin_image=False, + bin_spacing=30, + threshold_dynamic_binning=1000, + threshold_subdivide=20000, + correlation_length_cut=5, + log=None, + ) gp.run()