Skip to content

Commit

Permalink
have unit test running on gp
Browse files Browse the repository at this point in the history
  • Loading branch information
PFLeget committed Sep 19, 2024
1 parent e5e7974 commit c4bd940
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 78 deletions.
24 changes: 22 additions & 2 deletions python/lsst/meas/algorithms/gp_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,26 @@ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
self.mean = mean
self._alpha = None

# Looks weird to do that, but this is justified.
# in GP if no noise is provided, even if matrix
# can be inverted, it wont invert because of numerical
# issue (det(K)~0). Add a little bit of noise allow
# to compute a numerical solution in the case of no
# external noise is added. Wont happened on real
# image but help for unit test.
if self.white_noise == 0.0:
self.white_noise = 1e-5

def fit(self, x_train, y_train):
y = y_train - self.mean
self._x = x_train
kernel = jax_rbf_kernel(x_train, self.std, self.correlation_length)
kernel = jax_rbf_kernel(x_train, x_train, self.std, self.correlation_length)
y_err = jnp.ones(len(x_train[:, 0])) * self.white_noise
kernel += jnp.eye(len(y_err)) * (y_err**2)
self._alpha = jax_get_alpha(y, kernel)

def predict(self, x_predict):
kernel_rect = jax_rbf_kernel(x_predict, self._x, self.std, self.correlation_length, 0)
kernel_rect = jax_rbf_kernel(x_predict, self._x, self.std, self.correlation_length)
y_pred = jax_get_gp_predict(kernel_rect, self._alpha)
return y_pred + self.mean

Expand Down Expand Up @@ -236,6 +246,16 @@ def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
self.white_noise = white_noise
self.mean = mean

# Looks like weird to do that, but this is justified.
# in GP if no noise is provided, even if matrix
# can be inverted, it wont invert because of numerical
# issue (det(K)~0). Add a little bit of noise allow
# to compute a numerical solution in the case of no
# external noise is added. Wont happened on real
# image but help for unit test.
if self.white_noise == 0.0:
self.white_noise = 1e-5

def fit(self, x_train, y_train):
"""
Fit the Gaussian Process to the given training data.
Expand Down
96 changes: 20 additions & 76 deletions tests/test_gp_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,12 @@

import unittest

from typing import Iterable
from itertools import product
import numpy as np

import lsst.utils.tests
import lsst.geom
import lsst.afw.image as afwImage
from lsst.meas.algorithms.cloughTocher2DInterpolator import (
CloughTocher2DInterpolateTask,
)
from lsst.meas.algorithms import CloughTocher2DInterpolatorUtils as ctUtils
from lsst.meas.algorithms import InterpolateOverDefectGaussianProcess


class InterpolateOverDefectGaussianProcessTestCase(lsst.utils.tests.TestCase):
Expand Down Expand Up @@ -85,84 +80,30 @@ def setUp(self):
np.random.seed(12345)
self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape)

@lsst.utils.tests.methodParameters(n_runs=(1, 2))
def test_interpolation(self, n_runs: int):
@lsst.utils.tests.methodParameters(method=("treegp", "jax"))
def test_interpolation(self, method: str):
"""Test that the interpolation is done correctly.
Parameters
----------
n_runs : `int`
Number of times to run the task. Running the task more than once
should have no effect.
method : `str`
Code used to solve gaussian process.
"""
config = CloughTocher2DInterpolateTask.ConfigClass()
config.badMaskPlanes = (
"BAD",
"SAT",
"CR",
"EDGE",
)
config.fillValue = 0.5
task = CloughTocher2DInterpolateTask(config)
for n in range(n_runs):
task.run(self.maskedimage)

# Assert that the mask and the variance planes remain unchanged.
self.assertImagesEqual(self.maskedimage.variance, self.reference.variance)
self.assertMasksEqual(self.maskedimage.mask, self.reference.mask)
gp = InterpolateOverDefectGaussianProcess(self.maskedimage,

Check failure on line 93 in tests/test_gp_interp.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

W291

trailing whitespace
defects=["BAD", "SAT", "CR", "EDGE"],
method=method,
fwhm=5,
bin_spacing=10,
threshold_dynamic_binning=1000,
threshold_subdivide=20000,
correlation_length_cut=5,
log=None,)

# Check that the long streak of bad pixels have been replaced with the
# fillValue, but not the short streak.
np.testing.assert_array_equal(self.maskedimage.image[0:1, :].array, config.fillValue)
with self.assertRaises(AssertionError):
np.testing.assert_array_equal(self.maskedimage.image[25:28, 0:1].array, config.fillValue)
gp.run()

# Check that interpolated pixels are close to the reference (original),
# and that none of them is still NaN.
self.assertTrue(np.isfinite(self.maskedimage.image.array).all())
self.assertImagesAlmostEqual(
self.maskedimage.image[1:, :],
self.reference.image[1:, :],
rtol=1e-05,
atol=1e-08,
)
TEST = 42
if TEST == 42:
raise ValueError('TEST == 42')

@lsst.utils.tests.methodParametersProduct(pass_badpix=(True, False), pass_goodpix=(True, False))
def test_interpolation_with_noise(self, pass_badpix: bool = True, pass_goodpix: bool = True):
"""Test that we can reuse the badpix and goodpix.
Parameters
----------
pass_badpix : `bool`
Whether to pass the badpix to the task?
pass_goodpix : `bool`
Whether to pass the goodpix to the task?
"""

config = CloughTocher2DInterpolateTask.ConfigClass()
config.badMaskPlanes = (
"BAD",
"SAT",
"CR",
"EDGE",
)
task = CloughTocher2DInterpolateTask(config)

badpix, goodpix = task.run(self.noise)
task.run(
self.maskedimage,
badpix=(badpix if pass_badpix else None),
goodpix=(goodpix if pass_goodpix else None),
)

# Check that the long streak of bad pixels by the edge have been
# replaced with fillValue, but not the short streak.
np.testing.assert_array_equal(self.maskedimage.image[0:1, :].array, config.fillValue)
with self.assertRaises(AssertionError):
np.testing.assert_array_equal(self.maskedimage.image[25:28, 0:1].array, config.fillValue)
# Assert that the mask and the variance planes remain unchanged.
self.assertImagesEqual(self.maskedimage.variance, self.reference.variance)

# Check that interpolated pixels are close to the reference (original),
# and that none of them is still NaN.
Expand All @@ -173,9 +114,12 @@ def test_interpolation_with_noise(self, pass_badpix: bool = True, pass_goodpix:
rtol=1e-05,
atol=1e-08,
)

Check failure on line 117 in tests/test_gp_interp.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

W293

blank line contains whitespace
def test_dumb(self):
TEST = 42
if TEST == 42:
raise ValueError('TEST == 42')

Check failure on line 122 in tests/test_gp_interp.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

W293

blank line contains whitespace


def setup_module(module):

Check failure on line 125 in tests/test_gp_interp.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E303

too many blank lines (3)
Expand Down

0 comments on commit c4bd940

Please sign in to comment.