diff --git a/python/lsst/meas/algorithms/__init__.py b/python/lsst/meas/algorithms/__init__.py
index 561e281d7..e750993d2 100644
--- a/python/lsst/meas/algorithms/__init__.py
+++ b/python/lsst/meas/algorithms/__init__.py
@@ -61,6 +61,8 @@
from .accumulator_mean_stack import *
from .scaleVariance import *
from .noise_covariance import *
+from .gp_interpolation import *
+from .interp import *
from .reinterpolate_pixels import *
from .setPrimaryFlags import *
from .coaddBoundedField import *
diff --git a/python/lsst/meas/algorithms/gp_interpolation.py b/python/lsst/meas/algorithms/gp_interpolation.py
new file mode 100644
index 000000000..6de217f9a
--- /dev/null
+++ b/python/lsst/meas/algorithms/gp_interpolation.py
@@ -0,0 +1,510 @@
+# This file is part of meas_algorithms.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import numpy as np
+from lsst.meas.algorithms import CloughTocher2DInterpolatorUtils as ctUtils
+from lsst.geom import Box2I, Point2I, Extent2I
+from lsst.afw.geom import SpanSet
+import copy
+import treegp
+
+import jax
+from jax import jit
+import jax.numpy as jnp
+
+import logging
+
+__all__ = [
+ "InterpolateOverDefectGaussianProcess",
+ "GaussianProcessJax",
+ "GaussianProcessTreegp",
+]
+
+
+def updateMaskFromArray(mask, bad_pixel, interpBit):
+ """
+ Update the mask array with the given bad pixels.
+
+ Parameters
+ ----------
+    mask : `lsst.afw.image.Mask`
+        The mask plane to update.
+ bad_pixel : `np.array`
+ An array-like object containing the coordinates of the bad pixels.
+ Each row should contain the x and y coordinates of a bad pixel.
+ interpBit : `int`
+ The bit value to set for the bad pixels in the mask.
+ """
+ x0 = mask.getX0()
+ y0 = mask.getY0()
+ for row in bad_pixel:
+ x = int(row[0] - x0)
+ y = int(row[1] - y0)
+ mask.array[y, x] |= interpBit
+    # TODO: this could be vectorized, e.g.:
+    # mask.array[(bad_pixel[:, 1] - y0).astype(int), (bad_pixel[:, 0] - x0).astype(int)] |= interpBit
+
+
+@jit
+def median_with_mad_clipping(data, mad_multiplier=2.0):
+ """
+ Calculate the median of the input data after applying Median Absolute Deviation (MAD) clipping.
+
+    MAD clipping is used to suppress outliers: the median of the data is
+    computed, the MAD is computed as the median absolute deviation from that
+    median, and values outside the range median +/- mad_multiplier * MAD are
+    clamped to the range boundaries. Finally, the median of the clipped data
+    is returned.
+
+ Parameters:
+ -----------
+ data : `np.array`
+ Input data array.
+ mad_multiplier : `float`, optional
+ Multiplier for the MAD value used for clipping. Default is 2.0.
+
+ Returns:
+ --------
+ median_clipped : `float`
+ Median value of the clipped data.
+
+ Examples:
+ ---------
+ >>> data = [1, 2, 3, 4, 5, 100]
+ >>> median_with_mad_clipping(data)
+ 3.5
+ """
+ median = jnp.median(data)
+ mad = jnp.median(jnp.abs(data - median))
+ clipping_range = mad_multiplier * mad
+ clipped_data = jnp.clip(data, median - clipping_range, median + clipping_range)
+ median_clipped = jnp.median(clipped_data)
+ return median_clipped
+
+
+@jit
+def jax_rbf_kernel(x1, x2, sigma, correlation_length):
+ """
+ Computes the radial basis function (RBF) kernel matrix.
+
+ Parameters:
+ -----------
+ x1 : `np.array`
+ Location of training data point with shape (n_samples, n_features).
+ x2 : `np.array`
+ Location of training/test data point with shape (n_samples, n_features).
+ sigma : `float`
+ The scale parameter of the kernel.
+ correlation_length : `float`
+ The correlation length parameter of the kernel.
+
+ Returns:
+ --------
+    kernel : `np.array`
+        RBF kernel matrix with shape (n_samples_x1, n_samples_x2).
+ """
+ distance_squared = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
+ kernel = (sigma**2) * jnp.exp(-0.5 * distance_squared / (correlation_length**2))
+ return kernel
+
+
+@jit
+def jax_get_alpha(y, kernel):
+ """
+ Compute the alpha vector for Gaussian Process interpolation.
+
+ Parameters:
+ -----------
+ y : `np.array`
+ The target values of the Gaussian Process.
+ kernel : `np.array`
+ The kernel matrix of the Gaussian Process.
+
+ Returns:
+ --------
+ alpha : `np.array`
+ The alpha vector computed using the Cholesky decomposition and solution.
+
+ """
+ factor = (jax.scipy.linalg.cholesky(kernel, overwrite_a=True, lower=False), False)
+ alpha = jax.scipy.linalg.cho_solve(factor, y, overwrite_b=False)
+ return alpha.reshape((len(alpha), 1))
+
+
+@jit
+def jax_get_gp_predict(kernel_rect, alpha):
+ """
+ Compute the predicted values of gp using the given kernel and alpha (cholesky solution).
+
+ Parameters:
+ -----------
+ kernel_rect : `np.array`
+ The kernel matrix.
+ alpha : `np.array`
+ The alpha vector from Cholesky solution.
+
+ Returns:
+ --------
+ `np.array`
+ The predicted values of y.
+
+ """
+ return jnp.dot(kernel_rect, alpha).T[0]
+
+
+class GaussianProcessJax:
+ """
+ Gaussian Process regression in JAX.
+    The kernel is assumed to be an isotropic RBF kernel, and the linear system
+    is solved by exact Cholesky decomposition. The interpolation solution is
+    obtained by solving the linear system:
+    y_interp = kernel_rect @ (kernel + y_err**2 * I)^-1 @ y_training.
+    See the Rasmussen and Williams book for more details.
+    Each helper function is decorated with @jit so it is compiled once and
+    reused. Packages such as tinygp also implement Gaussian Processes in JAX;
+    this class is a custom implementation that allows the hyperparameters and
+    the mean function to be set explicitly.
+
+ Parameters:
+ -----------
+ std : `float`, optional
+ Standard deviation of the Gaussian Process kernel. Default is 1.0.
+ correlation_length : `float`, optional
+ Correlation length of the Gaussian Process kernel. Default is 1.0.
+ white_noise : `float`, optional
+ White noise level of the Gaussian Process. Default is 0.0.
+ mean : `float`, optional
+ Mean value of the Gaussian Process. Default is 0.0.
+
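+    Examples:
+    ---------
+    A minimal, illustrative sketch (the coordinates and values below are
+    arbitrary):
+
+    >>> import numpy as np
+    >>> x_train = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
+    >>> y_train = np.array([1.0, 2.0, 3.0])
+    >>> gp = GaussianProcessJax(std=2.0, correlation_length=1.0, mean=2.0)
+    >>> gp.fit(x_train, y_train)
+    >>> y_pred = gp.predict(np.array([[0.5, 0.5]]))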
+ """
+
+ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
+ self.std = std
+ self.correlation_length = correlation_length
+ self.white_noise = white_noise
+ self.mean = mean
+ self._alpha = None
+
+        # Adding a small white-noise term below is deliberate: if no noise is
+        # provided, the kernel matrix is numerically singular (det(K) ~ 0) even
+        # though it is formally invertible. A small jitter makes the Cholesky
+        # solve stable when no external noise is supplied. This does not happen
+        # on real images, but it helps in unit tests.
+ if self.white_noise == 0.0:
+ self.white_noise = 1e-5
+
+    def fit(self, x_train, y_train):
+        """Fit the GP to the training coordinates ``x_train`` and values ``y_train``."""
+        y = y_train - self.mean
+ self._x = x_train
+ kernel = jax_rbf_kernel(x_train, x_train, self.std, self.correlation_length)
+ y_err = jnp.ones(len(x_train[:, 0])) * self.white_noise
+ kernel += jnp.eye(len(y_err)) * (y_err**2)
+ self._alpha = jax_get_alpha(y, kernel)
+
+    def predict(self, x_predict):
+        """Predict values at the coordinates ``x_predict`` from the fitted GP."""
+        kernel_rect = jax_rbf_kernel(
+ x_predict, self._x, self.std, self.correlation_length
+ )
+ y_pred = jax_get_gp_predict(kernel_rect, self._alpha)
+ return y_pred + self.mean
+
+
+class GaussianProcessTreegp:
+ """
+    Gaussian Process interpolation using the treegp package.
+
+    Basic GP regression, solved by Cholesky decomposition.
+
+ Parameters:
+ -----------
+ std : `float`, optional
+ Standard deviation of the Gaussian Process kernel. Default is 1.0.
+ correlation_length : `float`, optional
+ Correlation length of the Gaussian Process kernel. Default is 1.0.
+ white_noise : `float`, optional
+ White noise level of the Gaussian Process. Default is 0.0.
+ mean : `float`, optional
+ Mean value of the Gaussian Process. Default is 0.0.
+ """
+
+ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
+ self.std = std
+ self.correlation_length = correlation_length
+ self.white_noise = white_noise
+ self.mean = mean
+
+        # Adding a small white-noise term below is deliberate: if no noise is
+        # provided, the kernel matrix is numerically singular (det(K) ~ 0) even
+        # though it is formally invertible. A small jitter makes the Cholesky
+        # solve stable when no external noise is supplied. This does not happen
+        # on real images, but it helps in unit tests.
+ if self.white_noise == 0.0:
+ self.white_noise = 1e-5
+
+ def fit(self, x_train, y_train):
+ """
+ Fit the Gaussian Process to the given training data.
+
+ Parameters:
+ -----------
+ x_train : `np.array`
+ Input features for the training data.
+ y_train : `np.array`
+ Target values for the training data.
+ """
+ kernel = f"{self.std}**2 * RBF({self.correlation_length})"
+ self.gp = treegp.GPInterpolation(
+ kernel=kernel,
+ optimizer="none",
+ normalize=False,
+ white_noise=self.white_noise,
+ )
+ self.gp.initialize(x_train, y_train - self.mean)
+ self.gp.solve()
+
+ def predict(self, x_predict):
+ """
+ Predict the target values for the given input features.
+
+ Parameters:
+ -----------
+ x_predict : `np.array`
+ Input features for the prediction.
+
+ Returns:
+ --------
+ y_pred : `np.array`
+ Predicted target values.
+ """
+ y_pred = self.gp.predict(x_predict)
+ return y_pred + self.mean
+
+
+class InterpolateOverDefectGaussianProcess:
+ """
+ InterpolateOverDefectGaussianProcess class performs Gaussian Process
+ (GP) interpolation over defects in an image.
+
+ Parameters:
+ -----------
+ masked_image : `lsst.afw.image.MaskedImage`
+ The masked image containing the defects to be interpolated.
+    defects : `list` [`str`], optional
+        Names of the mask planes to interpolate over. Default is ["SAT"].
+ method : `str`, optional
+ The method to use for GP interpolation. Must be either "jax" or "treegp". Default is "treegp".
+ fwhm : `float`, optional
+        The full width at half maximum (FWHM) of the PSF, used as the
+        correlation length of the GP kernel. Default is 5.
+    bin_image : `bool`, optional
+        Whether to bin the good pixels before fitting the GP. Default is True.
+ bin_spacing : `int`, optional
+ The spacing between bins for good pixel binning. Default is 10.
+ threshold_dynamic_binning : `int`, optional
+ The threshold for dynamic binning. Default is 1000.
+ threshold_subdivide : `int`, optional
+ The threshold for sub-dividing the bad pixel array to avoid memory error. Default is 20000.
+ correlation_length_cut : `int`, optional
+ The factor by which to dilate the bounding box around defects. Default is 5.
+ log : `lsst.log.Log`, `logging.Logger` or `None`, optional
+ Logger object used to write out messages. If `None` a default
+ logger will be used.
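+
+    Examples:
+    ---------
+    A minimal, illustrative sketch; ``masked_image`` stands for any
+    `lsst.afw.image.MaskedImage` whose mask has the defect planes set:
+
+    >>> gp = InterpolateOverDefectGaussianProcess(  # doctest: +SKIP
+    ...     masked_image, defects=["SAT", "BAD"], method="treegp", fwhm=5
+    ... )
+    >>> gp.run()  # doctest: +SKIP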
+ """
+
+ def __init__(
+ self,
+ masked_image,
+ defects=["SAT"],
+ method="treegp",
+ fwhm=5,
+ bin_image=True,
+ bin_spacing=10,
+ threshold_dynamic_binning=1000,
+ threshold_subdivide=20000,
+ correlation_length_cut=5,
+ log=None,
+ ):
+ if method == "jax":
+ self.GaussianProcess = GaussianProcessJax
+ elif method == "treegp":
+ self.GaussianProcess = GaussianProcessTreegp
+ else:
+ raise ValueError("Invalid method. Must be 'jax' or 'treegp'.")
+
+ self.log = log or logging.getLogger(__name__)
+
+ self.bin_image = bin_image
+ self.bin_spacing = bin_spacing
+ self.threshold_subdivide = threshold_subdivide
+ self.threshold_dynamic_binning = threshold_dynamic_binning
+
+ self.masked_image = masked_image
+ self.defects = defects
+ self.correlation_length = fwhm
+ self.correlation_length_cut = correlation_length_cut
+
+ self.interpBit = self.masked_image.mask.getPlaneBitMask("INTRP")
+
+ def run(self):
+ """
+ Interpolate over the defects in the image.
+
+        The interpolation is performed in place on ``self.masked_image``.
+ """
+ if self.defects == [] or self.defects is None:
+ self.log.info("No defects found. No interpolation performed.")
+ else:
+ mask = self.masked_image.getMask()
+ bad_pixel_mask = mask.getPlaneBitMask(self.defects)
+ bad_mask_span_set = SpanSet.fromMask(mask, bad_pixel_mask).split()
+
+ bbox = self.masked_image.getBBox()
+ global_xmin, global_xmax = bbox.minX, bbox.maxX
+ global_ymin, global_ymax = bbox.minY, bbox.maxY
+
+ for spanset in bad_mask_span_set:
+ bbox = spanset.getBBox()
+ # Dilate the bbox to make sure we have enough good pixels around the defect
+ # For now, we dilate by 5 times the correlation length
+ # For GP with the isotropic kernel, points at the default value of
+ # correlation_length_cut=5 have negligible effect on the prediction.
+ bbox = bbox.dilatedBy(
+ int(self.correlation_length * self.correlation_length_cut)
+ ) # need integer as input.
+                xmin, xmax = max(global_xmin, bbox.minX), min(global_xmax, bbox.maxX)
+                ymin, ymax = max(global_ymin, bbox.minY), min(global_ymax, bbox.maxY)
+                localBox = Box2I(Point2I(xmin, ymin), Extent2I(xmax - xmin, ymax - ymin))
+ masked_sub_image = self.masked_image[localBox]
+
+ masked_sub_image = self.interpolate_masked_sub_image(masked_sub_image)
+ self.masked_image[localBox] = masked_sub_image
+
+ def _good_pixel_binning(self, pixels):
+ """
+ Performs pixel binning using treegp.meanify
+
+ Parameters:
+ -----------
+ pixels : `np.array`
+ The array of pixels.
+
+ Returns:
+ --------
+ `np.array`
+ The binned array of pixels.
+ """
+
+ n_pixels = len(pixels[:, 0])
+ dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning))
+ if n_pixels / self.bin_spacing**2 < n_pixels / dynamic_binning**2:
+ bin_spacing = self.bin_spacing
+ else:
+ bin_spacing = dynamic_binning
+ binning = treegp.meanify(bin_spacing=bin_spacing, statistics="mean")
+ binning.add_field(
+ pixels[:, :2],
+ pixels[:, 2:].T,
+ )
+ binning.meanify()
+ return np.array(
+ [binning.coords0[:, 0], binning.coords0[:, 1], binning.params0]
+ ).T
+
+ def interpolate_masked_sub_image(self, masked_sub_image):
+ """
+ Interpolate the masked sub-image.
+
+ Parameters:
+ -----------
+ masked_sub_image : `lsst.afw.image.MaskedImage`
+ The sub-masked image to be interpolated.
+
+ Returns:
+ --------
+ `lsst.afw.image.MaskedImage`
+ The interpolated sub-masked image.
+ """
+
+ cut = int(
+ self.correlation_length * self.correlation_length_cut
+ ) # need integer as input.
+ bad_pixel, good_pixel = ctUtils.findGoodPixelsAroundBadPixels(
+ masked_sub_image, self.defects, buffer=cut
+ )
+        # Do nothing if there are no bad or no good pixels.
+ if bad_pixel.size == 0 or good_pixel.size == 0:
+ self.log.info("No bad or good pixels found. No interpolation performed.")
+ return masked_sub_image
+ # Do GP interpolation if bad pixel found.
+ else:
+ # gp interpolation
+ sub_image_array = masked_sub_image.getVariance().array
+ white_noise = np.sqrt(
+ np.mean(sub_image_array[np.isfinite(sub_image_array)])
+ )
+ kernel_amplitude = np.max(good_pixel[:, 2:])
+ if not np.isfinite(kernel_amplitude):
+ filter_finite = np.isfinite(good_pixel[:, 2:]).T[0]
+ good_pixel = good_pixel[filter_finite]
+ if good_pixel.size == 0:
+ self.log.info(
+ "No bad or good pixels found. No interpolation performed."
+ )
+ return masked_sub_image
+                # The maximum of the remaining good-pixel values is used as the
+                # kernel amplitude, since the data are not really a stationary
+                # Gaussian random field.
+                kernel_amplitude = np.max(good_pixel[:, 2:])
+
+ if self.bin_image:
+ try:
+ good_pixel = self._good_pixel_binning(copy.deepcopy(good_pixel))
+ except Exception:
+ self.log.info(
+ "Binning failed, use original good pixel array in interpolation."
+ )
+
+            # Compute the clipped median after binning, since the median computation is O(n log n).
+ clipped_median = median_with_mad_clipping(good_pixel[:, 2:])
+
+ gp = self.GaussianProcess(
+ std=np.sqrt(kernel_amplitude),
+ correlation_length=self.correlation_length,
+ white_noise=white_noise,
+ mean=clipped_median,
+ )
+ gp.fit(good_pixel[:, :2], np.squeeze(good_pixel[:, 2:]))
+ if bad_pixel.size < self.threshold_subdivide:
+ gp_predict = gp.predict(bad_pixel[:, :2])
+ bad_pixel[:, 2:] = gp_predict.reshape(np.shape(bad_pixel[:, 2:]))
+ else:
+ self.log.info("sub-divide bad pixel array to avoid memory error.")
+ for i in range(0, len(bad_pixel), self.threshold_subdivide):
+ end = min(i + self.threshold_subdivide, len(bad_pixel))
+ gp_predict = gp.predict(bad_pixel[i:end, :2])
+ bad_pixel[i:end, 2:] = gp_predict.reshape(
+ np.shape(bad_pixel[i:end, 2:])
+ )
+
+ # Update values
+ ctUtils.updateImageFromArray(masked_sub_image.image, bad_pixel)
+ updateMaskFromArray(masked_sub_image.mask, bad_pixel, self.interpBit)
+ return masked_sub_image
diff --git a/python/lsst/meas/algorithms/interp.cc b/python/lsst/meas/algorithms/interp.cc
index 885226d40..142b41976 100644
--- a/python/lsst/meas/algorithms/interp.cc
+++ b/python/lsst/meas/algorithms/interp.cc
@@ -37,7 +37,7 @@ namespace {
template <typename PixelT>
void declareInterpolateOverDefects(py::module& mod) {
- mod.def("interpolateOverDefects",
+ mod.def("legacyInterpolateOverDefects",
interpolateOverDefects<
afw::image::MaskedImage<PixelT>>,
"image"_a, "psf"_a, "badList"_a, "fallBackValue"_a = 0.0, "useFallbackValueAtEdge"_a = false);
diff --git a/python/lsst/meas/algorithms/interp.py b/python/lsst/meas/algorithms/interp.py
new file mode 100644
index 000000000..554a3cf92
--- /dev/null
+++ b/python/lsst/meas/algorithms/interp.py
@@ -0,0 +1,25 @@
+from . import legacyInterpolateOverDefects
+from . import InterpolateOverDefectGaussianProcess
+
+__all__ = ["interpolateOverDefects"]
+
+
+def interpolateOverDefects(
+ image,
+ psf,
+ badList,
+ fallbackValue=0.0,
+ useFallbackValueAtEdge=False,
+ fwhm=1.0,
+ useLegacyInterp=True,
+ maskNameList=None,
+ **kwargs
+):
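+    """Interpolate over defects in a masked image, in place.
+
+    If ``useLegacyInterp`` is `True`, the legacy C++ ``legacyInterpolateOverDefects``
+    is called with ``image``, ``psf``, ``badList``, ``fallbackValue`` and
+    ``useFallbackValueAtEdge``. Otherwise, Gaussian Process interpolation is run
+    over the mask planes listed in ``maskNameList``, using ``fwhm`` as the
+    correlation length; any extra keyword arguments are forwarded to
+    `InterpolateOverDefectGaussianProcess`.
+    """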
+ if useLegacyInterp:
+ return legacyInterpolateOverDefects(
+ image, psf, badList, fallbackValue, useFallbackValueAtEdge
+ )
+ else:
+ gp = InterpolateOverDefectGaussianProcess(image, fwhm=fwhm,
+ defects=maskNameList, **kwargs)
+ return gp.run()
diff --git a/python/lsst/meas/algorithms/reinterpolate_pixels.py b/python/lsst/meas/algorithms/reinterpolate_pixels.py
index b04c2ccd3..a8719f94d 100644
--- a/python/lsst/meas/algorithms/reinterpolate_pixels.py
+++ b/python/lsst/meas/algorithms/reinterpolate_pixels.py
@@ -27,7 +27,7 @@
import lsst.afw.math as afwMath
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
-from lsst.meas.algorithms import Defect, interpolateOverDefects
+from . import Defect, interpolateOverDefects
class ReinterpolatePixelsConfig(pexConfig.Config):
diff --git a/tests/test_gp_interp.py b/tests/test_gp_interp.py
new file mode 100644
index 000000000..19f5c716b
--- /dev/null
+++ b/tests/test_gp_interp.py
@@ -0,0 +1,193 @@
+# This file is part of meas_algorithms.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+
+import unittest
+
+import numpy as np
+
+import lsst.utils.tests
+import lsst.geom
+import lsst.afw.image as afwImage
+from lsst.meas.algorithms import (
+ InterpolateOverDefectGaussianProcess,
+ GaussianProcessTreegp,
+)
+
+
+def rbf_kernel(x1, x2, sigma, correlation_length):
+ """
+ Computes the radial basis function (RBF) kernel matrix.
+
+ Parameters:
+ -----------
+ x1 : `np.array`
+ Location of training data point with shape (n_samples, n_features).
+ x2 : `np.array`
+ Location of training/test data point with shape (n_samples, n_features).
+ sigma : `float`
+ The scale parameter of the kernel.
+ correlation_length : `float`
+ The correlation length parameter of the kernel.
+
+ Returns:
+ --------
+    kernel : `np.array`
+        RBF kernel matrix with shape (n_samples_x1, n_samples_x2).
+ """
+ distance_squared = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
+ kernel = (sigma**2) * np.exp(-0.5 * distance_squared / (correlation_length**2))
+ return kernel
+
+
+class InterpolateOverDefectGaussianProcessTestCase(lsst.utils.tests.TestCase):
+ """Test InterpolateOverDefectGaussianProcess."""
+
+ def setUp(self):
+ super().setUp()
+
+ npoints = 1000
+ self.std = 100
+ self.correlation_length = 10.0
+ self.white_noise = 1e-5
+
+        np.random.seed(42)
+        x1 = np.random.uniform(0, 99, npoints)
+        x2 = np.random.uniform(0, 120, npoints)
+        coord1 = np.array([x1, x2]).T
+
+ kernel = rbf_kernel(coord1, coord1, self.std, self.correlation_length)
+ kernel += np.eye(npoints) * self.white_noise**2
+
+        # Data augmentation. Creating a Gaussian random field directly on the
+        # full 100 x 121 grid is too slow, so generate 1e3 points and then
+        # interpolate them onto the grid with a GP.
+
+        z1 = np.random.multivariate_normal(np.zeros(npoints), kernel)
+
+ x1 = np.linspace(0, 99, 100)
+ x2 = np.linspace(0, 120, 121)
+ x2, x1 = np.meshgrid(x2, x1)
+ coord2 = np.array([x1.reshape(-1), x2.reshape(-1)]).T
+
+ tgp = GaussianProcessTreegp(
+ std=self.std,
+ correlation_length=self.correlation_length,
+ white_noise=self.white_noise,
+ mean=0.0,
+ )
+ tgp.fit(coord1, z1)
+ z2 = tgp.predict(coord2)
+ z2 = z2.reshape(100, 121)
+
+ self.maskedimage = afwImage.MaskedImageF(100, 121)
+ for x in range(100):
+ for y in range(121):
+ self.maskedimage[x, y] = (z2[x, y], 0, 1.0)
+
+ # Clone the maskedimage so we can compare it after running the task.
+ self.reference = self.maskedimage.clone()
+
+ # Set some central pixels as SAT
+ sliceX, sliceY = slice(30, 35), slice(40, 45)
+ self.maskedimage.mask[sliceX, sliceY] = afwImage.Mask.getPlaneBitMask("SAT")
+ self.maskedimage.image[sliceX, sliceY] = np.nan
+ # Put nans here to make sure interp is done ok
+
+ # Set an entire column as BAD
+ self.maskedimage.mask[54:55, :] = afwImage.Mask.getPlaneBitMask("BAD")
+ self.maskedimage.image[54:55, :] = np.nan
+
+ # Set an entire row as BAD
+ self.maskedimage.mask[:, 110:111] = afwImage.Mask.getPlaneBitMask("BAD")
+ self.maskedimage.image[:, 110:111] = np.nan
+
+ # Set a diagonal set of pixels as CR
+ for i in range(74, 78):
+ self.maskedimage.mask[i, i] = afwImage.Mask.getPlaneBitMask("CR")
+ self.maskedimage.image[i, i] = np.nan
+
+ # Set one of the edges as EDGE
+ self.maskedimage.mask[0:1, :] = afwImage.Mask.getPlaneBitMask("EDGE")
+ self.maskedimage.image[0:1, :] = np.nan
+
+ # Set a smaller streak at the edge
+ self.maskedimage.mask[25:28, 0:1] = afwImage.Mask.getPlaneBitMask("EDGE")
+ self.maskedimage.image[25:28, 0:1] = np.nan
+
+ # Update the reference image's mask alone, so we can compare them after
+ # running the task.
+ self.reference.mask.array[:, :] = self.maskedimage.mask.array
+
+ # Create a noise image
+ # self.noise = self.maskedimage.clone()
+ # np.random.seed(12345)
+ # self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape)
+
+    @lsst.utils.tests.methodParameters(method=("treegp", "jax"))
+ def test_interpolation(self, method: str):
+ """Test that the interpolation is done correctly.
+
+ Parameters
+ ----------
+        method : `str`
+            The GP solver to use, either "treegp" or "jax".
+ """
+
+ gp = InterpolateOverDefectGaussianProcess(
+ self.maskedimage,
+ defects=["BAD", "SAT", "CR", "EDGE"],
+ method=method,
+ fwhm=self.correlation_length,
+ bin_image=False,
+ bin_spacing=30,
+ threshold_dynamic_binning=1000,
+ threshold_subdivide=20000,
+ correlation_length_cut=5,
+ log=None,
+ )
+
+ gp.run()
+
+        # Assert that the variance plane remains unchanged.
+ self.assertImagesEqual(self.maskedimage.variance, self.reference.variance)
+
+ # Check that interpolated pixels are close to the reference (original),
+ # and that none of them is still NaN.
+ self.assertTrue(np.isfinite(self.maskedimage.image.array).all())
+ self.assertImagesAlmostEqual(
+ self.maskedimage.image[1:, :],
+ self.reference.image[1:, :],
+ atol=2,
+ )
+
+
+def setup_module(module):
+ lsst.utils.tests.init()
+
+
+class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()